diff --git a/.gitignore b/.gitignore index 6b2d6fc97b..8fda9f2980 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,6 @@ cmake_build *.lo *.tar.gz *.log +.coverage +*.pyc +cov_html/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d46ed6070..7f6a3d37f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ Please mark all change in change log and use the ticket from JIRA. ## Feature - \#12 - Pure CPU version for Milvus +- \#226 - Experimental shards middleware for Milvus ## Improvement @@ -84,7 +85,7 @@ Please mark all change in change log and use the ticket from JIRA. - MS-658 - Fix SQ8 Hybrid can't search - MS-665 - IVF_SQ8H search crash when no GPU resource in search_resources - \#9 - Change default gpu_cache_capacity to 4 -- \#20 - C++ sdk example get grpc error +- \#20 - C++ sdk example get grpc error - \#23 - Add unittest to improve code coverage - \#31 - make clang-format failed after run build.sh -l - \#39 - Create SQ8H index hang if using github server version @@ -136,7 +137,7 @@ Please mark all change in change log and use the ticket from JIRA. - MS-635 - Add compile option to support customized faiss - MS-660 - add ubuntu_build_deps.sh - \#18 - Add all test cases - + # Milvus 0.4.0 (2019-09-12) ## Bug @@ -345,11 +346,11 @@ Please mark all change in change log and use the ticket from JIRA. - MS-82 - Update server startup welcome message - MS-83 - Update vecwise to Milvus - MS-77 - Performance issue of post-search action -- MS-22 - Enhancement for MemVector size control +- MS-22 - Enhancement for MemVector size control - MS-92 - Unify behavior of debug and release build - MS-98 - Install all unit test to installation directory - MS-115 - Change is_startup of metric_config switch from true to on -- MS-122 - Archive criteria config +- MS-122 - Archive criteria config - MS-124 - HasTable interface - MS-126 - Add more error code - MS-128 - Change default db path diff --git a/shards/.dockerignore b/shards/.dockerignore new file mode 100644 index 0000000000..e450610057 --- /dev/null +++ b/shards/.dockerignore @@ -0,0 +1,13 @@ +.git +.gitignore +.env +.coverage +.dockerignore +cov_html/ + +.pytest_cache +__pycache__ +*/__pycache__ +*.md +*.yml +*.yaml diff --git a/shards/Dockerfile b/shards/Dockerfile new file mode 100644 index 0000000000..594640619e --- /dev/null +++ b/shards/Dockerfile @@ -0,0 +1,10 @@ +FROM python:3.6 +RUN apt update && apt install -y \ + less \ + telnet +RUN mkdir /source +WORKDIR /source +ADD ./requirements.txt ./ +RUN pip install -r requirements.txt +COPY . . +CMD python mishards/main.py diff --git a/shards/Makefile b/shards/Makefile new file mode 100644 index 0000000000..c8aa6127f8 --- /dev/null +++ b/shards/Makefile @@ -0,0 +1,35 @@ +HOST=$(or $(host),127.0.0.1) +PORT=$(or $(port),19530) + +build: + docker build --network=host -t milvusdb/mishards . +push: + docker push milvusdb/mishards +pull: + docker pull milvusdb/mishards +deploy: clean_deploy + cd all_in_one && docker-compose -f all_in_one.yml up -d && cd - +clean_deploy: + cd all_in_one && docker-compose -f all_in_one.yml down && cd - +probe_deploy: + docker run --rm --name probe --net=host milvusdb/mishards /bin/bash -c "python all_in_one/probe_test.py" +cluster: + cd kubernetes_demo;./start.sh baseup;sleep 10;./start.sh appup;cd - +clean_cluster: + cd kubernetes_demo;./start.sh cleanup;cd - +cluster_status: + kubectl get pods -n milvus -o wide +probe_cluster: + @echo + $(shell kubectl get service -n milvus | grep milvus-proxy-servers | awk {'print $$4,$$5'} | awk -F"[: ]" {'print "docker run --rm --name probe --net=host milvusdb/mishards /bin/bash -c \"python all_in_one/probe_test.py --port="$$2" --host="$$1"\""'}) +probe: + docker run --rm --name probe --net=host milvusdb/mishards /bin/bash -c "python all_in_one/probe_test.py --port=${PORT} --host=${HOST}" +clean_coverage: + rm -rf cov_html +clean: clean_coverage clean_deploy clean_cluster +style: + pycodestyle --config=. +coverage: + pytest --cov-report html:cov_html --cov=mishards +test: + pytest diff --git a/shards/Tutorial_CN.md b/shards/Tutorial_CN.md new file mode 100644 index 0000000000..192a0fd285 --- /dev/null +++ b/shards/Tutorial_CN.md @@ -0,0 +1,147 @@ +# Mishards使用文档 +--- +Milvus 旨在帮助用户实现海量非结构化数据的近似检索和分析。单个 Milvus 实例可处理十亿级数据规模,而对于百亿或者千亿规模数据的需求,则需要一个 Milvus 集群实例,该实例对于上层应用可以像单机实例一样使用,同时满足海量数据低延迟,高并发业务需求。mishards就是一个集群中间件,其内部处理请求转发,读写分离,水平扩展,动态扩容,为用户提供内存和算力可以无限扩容的 Milvus 实例。 + +## 运行环境 +--- + +### 单机快速启动实例 +**`python >= 3.4`环境** + +``` +1. cd milvus/shards +2. pip install -r requirements.txt +3. nvidia-docker run --rm -d -p 19530:19530 -v /tmp/milvus/db:/opt/milvus/db milvusdb/milvus:0.5.0-d102119-ede20b +4. sudo chown -R $USER:$USER /tmp/milvus +5. cp mishards/.env.example mishards/.env +6. 在python mishards/main.py #.env配置mishards监听19532端口 +7. make probe port=19532 #健康检查 +``` + +### 容器启动实例 +`all_in_one`会在服务器上开启两个milvus实例,一个mishards实例,一个jaeger链路追踪实例 + +**启动** +``` +cd milvus/shards +1. 安装docker-compose +2. make build +3. make deploy #监听19531端口 +4. make clean_deploy #清理服务 +5. make probe_deplopy #健康检查 +``` + +**打开Jaeger UI** +``` +浏览器打开 "http://127.0.0.1:16686/" +``` + +### kubernetes中快速启动 +**准备** +``` +- kubernetes集群 +- 安装nvidia-docker +- 共享存储 +- 安装kubectl并能访问集群 +``` + +**步骤** +``` +cd milvus/shards +1. make deploy_cluster #启动集群 +2. make probe_cluster #健康检查 +3. make clean_cluster #关闭集群 +``` + +**扩容计算实例** +``` +cd milvus/shards/kubernetes_demo/ +./start.sh scale-ro-server 2 扩容计算实例到2 +``` + +**扩容代理器实例** +``` +cd milvus/shards/kubernetes_demo/ +./start.sh scale-proxy 2 扩容代理服务器实例到2 +``` + +**查看日志** +``` +kubectl logs -f --tail=1000 -n milvus milvus-ro-servers-0 查看计算节点milvus-ro-servers-0日志 +``` + +## 测试 + +**启动单元测试** +``` +1. cd milvus/shards +2. make test +``` + +**单元测试覆盖率** +``` +1. cd milvus/shards +2. make coverage +``` + +**代码风格检查** +``` +1. cd milvus/shards +2. make style +``` + +## mishards配置详解 + +### 全局 +| Name | Required | Type | Default Value | Explanation | +| --------------------------- | -------- | -------- | ------------- | ------------- | +| Debug | No | bool | True | 是否Debug工作模式 | +| TIMEZONE | No | string | "UTC" | 时区 | +| MAX_RETRY | No | int | 3 | 最大连接重试次数 | +| SERVER_PORT | No | int | 19530 | 配置服务端口 | +| WOSERVER | **Yes** | str | - | 配置后台可写Milvus实例地址。目前只支持静态设置,例"tcp://127.0.0.1:19530" | + +### 元数据 +| Name | Required | Type | Default Value | Explanation | +| --------------------------- | -------- | -------- | ------------- | ------------- | +| SQLALCHEMY_DATABASE_URI | **Yes** | string | - | 配置元数据存储数据库地址 | +| SQL_ECHO | No | bool | False | 是否打印Sql详细语句 | +| SQLALCHEMY_DATABASE_TEST_URI | No | string | - | 配置测试环境下元数据存储数据库地址 | +| SQL_TEST_ECHO | No | bool | False | 配置测试环境下是否打印Sql详细语句 | + +### 服务发现 +| Name | Required | Type | Default Value | Explanation | +| --------------------------- | -------- | -------- | ------------- | ------------- | +| DISCOVERY_PLUGIN_PATH | No | string | - | 用户自定义服务发现插件搜索路径,默认使用系统搜索路径| +| DISCOVERY_CLASS_NAME | No | string | static | 在服务发现插件搜索路径下搜索类并实例化。目前系统提供 **static** 和 **kubernetes** 两种类,默认使用 **static** | +| DISCOVERY_STATIC_HOSTS | No | list | [] | **DISCOVERY_CLASS_NAME** 为 **static** 时,配置服务地址列表,例"192.168.1.188,192.168.1.190"| +| DISCOVERY_STATIC_PORT | No | int | 19530 | **DISCOVERY_CLASS_NAME** 为 **static** 时,配置 Hosts 监听端口 | +| DISCOVERY_KUBERNETES_NAMESPACE | No | string | - | **DISCOVERY_CLASS_NAME** 为 **kubernetes** 时,配置集群 namespace | +| DISCOVERY_KUBERNETES_IN_CLUSTER | No | bool | False | **DISCOVERY_CLASS_NAME** 为 **kubernetes** 时,标明服务发现是否在集群中运行 | +| DISCOVERY_KUBERNETES_POLL_INTERVAL | No | int | 5 | **DISCOVERY_CLASS_NAME** 为 **kubernetes** 时,标明服务发现监听服务列表频率,单位 Second | +| DISCOVERY_KUBERNETES_POD_PATT | No | string | - | **DISCOVERY_CLASS_NAME** 为 **kubernetes** 时,匹配可读 Milvus 实例的正则表达式 | +| DISCOVERY_KUBERNETES_LABEL_SELECTOR | No | string | - | **SD_PROVIDER** 为**Kubernetes**时,匹配可读Milvus实例的标签选择 | + +### 链路追踪 +| Name | Required | Type | Default Value | Explanation | +| --------------------------- | -------- | -------- | ------------- | ------------- | +| TRACER_PLUGIN_PATH | No | string | - | 用户自定义链路追踪插件搜索路径,默认使用系统搜索路径| +| TRACER_CLASS_NAME | No | string | "" | 链路追踪方案选择,目前只实现 **Jaeger**, 默认不使用| +| TRACING_SERVICE_NAME | No | string | "mishards" | **TRACING_TYPE** 为 **Jaeger** 时,链路追踪服务名 | +| TRACING_SAMPLER_TYPE | No | string | "const" | **TRACING_TYPE** 为 **Jaeger** 时,链路追踪采样类型 | +| TRACING_SAMPLER_PARAM | No | int | 1 | **TRACING_TYPE** 为 **Jaeger** 时,链路追踪采样频率 | +| TRACING_LOG_PAYLOAD | No | bool | False | **TRACING_TYPE** 为 **Jaeger** 时,链路追踪是否采集 Payload | + +### 日志 +| Name | Required | Type | Default Value | Explanation | +| --------------------------- | -------- | -------- | ------------- | ------------- | +| LOG_LEVEL | No | string | "DEBUG" if Debug is ON else "INFO" | 日志记录级别 | +| LOG_PATH | No | string | "/tmp/mishards" | 日志记录路径 | +| LOG_NAME | No | string | "logfile" | 日志记录名 | + +### 路由 +| Name | Required | Type | Default Value | Explanation | +| --------------------------- | -------- | -------- | ------------- | ------------- | +| ROUTER_PLUGIN_PATH | No | string | - | 用户自定义路由插件搜索路径,默认使用系统搜索路径| +| ROUTER_CLASS_NAME | No | string | FileBasedHashRingRouter | 处理请求路由类名, 可注册自定义类。目前系统只提供了类 **FileBasedHashRingRouter** | +| ROUTER_CLASS_TEST_NAME | No | string | FileBasedHashRingRouter | 测试环境下处理请求路由类名, 可注册自定义类 | diff --git a/shards/all_in_one/all_in_one.yml b/shards/all_in_one/all_in_one.yml new file mode 100644 index 0000000000..40473fe8b9 --- /dev/null +++ b/shards/all_in_one/all_in_one.yml @@ -0,0 +1,53 @@ +version: "2.3" +services: + milvus_wr: + runtime: nvidia + restart: always + image: milvusdb/milvus:0.5.0-d102119-ede20b + volumes: + - /tmp/milvus/db:/opt/milvus/db + + milvus_ro: + runtime: nvidia + restart: always + image: milvusdb/milvus:0.5.0-d102119-ede20b + volumes: + - /tmp/milvus/db:/opt/milvus/db + - ./ro_server.yml:/opt/milvus/conf/server_config.yaml + + jaeger: + restart: always + image: jaegertracing/all-in-one:1.14 + ports: + - "0.0.0.0:5775:5775/udp" + - "0.0.0.0:16686:16686" + - "0.0.0.0:9441:9441" + environment: + COLLECTOR_ZIPKIN_HTTP_PORT: 9411 + + mishards: + restart: always + image: milvusdb/mishards + ports: + - "0.0.0.0:19531:19531" + - "0.0.0.0:19532:19532" + volumes: + - /tmp/milvus/db:/tmp/milvus/db + # - /tmp/mishards_env:/source/mishards/.env + command: ["python", "mishards/main.py"] + environment: + FROM_EXAMPLE: 'true' + DEBUG: 'true' + SERVER_PORT: 19531 + WOSERVER: tcp://milvus_wr:19530 + DISCOVERY_PLUGIN_PATH: static + DISCOVERY_STATIC_HOSTS: milvus_wr,milvus_ro + TRACER_CLASS_NAME: jaeger + TRACING_SERVICE_NAME: mishards-demo + TRACING_REPORTING_HOST: jaeger + TRACING_REPORTING_PORT: 5775 + + depends_on: + - milvus_wr + - milvus_ro + - jaeger diff --git a/shards/all_in_one/probe_test.py b/shards/all_in_one/probe_test.py new file mode 100644 index 0000000000..6250465910 --- /dev/null +++ b/shards/all_in_one/probe_test.py @@ -0,0 +1,25 @@ +from milvus import Milvus + +RED = '\033[0;31m' +GREEN = '\033[0;32m' +ENDC = '' + + +def test(host='127.0.0.1', port=19531): + client = Milvus() + try: + status = client.connect(host=host, port=port) + if status.OK(): + print('{}Pass: Connected{}'.format(GREEN, ENDC)) + return 0 + else: + print('{}Error: {}{}'.format(RED, status, ENDC)) + return 1 + except Exception as exc: + print('{}Error: {}{}'.format(RED, exc, ENDC)) + return 1 + + +if __name__ == '__main__': + import fire + fire.Fire(test) diff --git a/shards/all_in_one/ro_server.yml b/shards/all_in_one/ro_server.yml new file mode 100644 index 0000000000..10cf695448 --- /dev/null +++ b/shards/all_in_one/ro_server.yml @@ -0,0 +1,41 @@ +server_config: + address: 0.0.0.0 # milvus server ip address (IPv4) + port: 19530 # port range: 1025 ~ 65534 + deploy_mode: cluster_readonly # deployment type: single, cluster_readonly, cluster_writable + time_zone: UTC+8 + +db_config: + primary_path: /opt/milvus # path used to store data and meta + secondary_path: # path used to store data only, split by semicolon + + backend_url: sqlite://:@:/ # URI format: dialect://username:password@host:port/database + # Keep 'dialect://:@:/', and replace other texts with real values + # Replace 'dialect' with 'mysql' or 'sqlite' + + insert_buffer_size: 4 # GB, maximum insert buffer size allowed + # sum of insert_buffer_size and cpu_cache_capacity cannot exceed total memory + + preload_table: # preload data at startup, '*' means load all tables, empty value means no preload + # you can specify preload tables like this: table1,table2,table3 + +metric_config: + enable_monitor: false # enable monitoring or not + collector: prometheus # prometheus + prometheus_config: + port: 8080 # port prometheus uses to fetch metrics + +cache_config: + cpu_cache_capacity: 16 # GB, CPU memory used for cache + cpu_cache_threshold: 0.85 # percentage of data that will be kept when cache cleanup is triggered + gpu_cache_capacity: 4 # GB, GPU memory used for cache + gpu_cache_threshold: 0.85 # percentage of data that will be kept when cache cleanup is triggered + cache_insert_data: false # whether to load inserted data into cache + +engine_config: + use_blas_threshold: 20 # if nq < use_blas_threshold, use SSE, faster with fluctuated response times + # if nq >= use_blas_threshold, use OpenBlas, slower with stable response times + +resource_config: + search_resources: # define the GPUs used for search computation, valid value: gpux + - gpu0 + index_build_device: gpu0 # GPU used for building index diff --git a/shards/conftest.py b/shards/conftest.py new file mode 100644 index 0000000000..4cdcbdbe0c --- /dev/null +++ b/shards/conftest.py @@ -0,0 +1,39 @@ +import os +import logging +import pytest +import grpc +import tempfile +import shutil +from mishards import settings, db, create_app + +logger = logging.getLogger(__name__) + +tpath = tempfile.mkdtemp() +dirpath = '{}/db'.format(tpath) +filepath = '{}/meta.sqlite'.format(dirpath) +os.makedirs(dirpath, 0o777) +settings.TestingConfig.SQLALCHEMY_DATABASE_URI = 'sqlite:///{}?check_same_thread=False'.format( + filepath) + + +@pytest.fixture +def app(request): + app = create_app(settings.TestingConfig) + db.drop_all() + db.create_all() + + yield app + + db.drop_all() + app.stop() + # shutil.rmtree(tpath) + + +@pytest.fixture +def started_app(app): + app.on_pre_run() + app.start(settings.SERVER_TEST_PORT) + + yield app + + app.stop() diff --git a/shards/discovery/__init__.py b/shards/discovery/__init__.py new file mode 100644 index 0000000000..a591d1cc1c --- /dev/null +++ b/shards/discovery/__init__.py @@ -0,0 +1,37 @@ +import os +import os +import sys +if __name__ == '__main__': + sys.path.append(os.path.dirname(os.path.dirname( + os.path.abspath(__file__)))) + +import logging +from utils import dotdict + +logger = logging.getLogger(__name__) + + +class DiscoveryConfig(dotdict): + CONFIG_PREFIX = 'DISCOVERY_' + + def dump(self): + logger.info('----------- DiscoveryConfig -----------------') + for k, v in self.items(): + logger.info('{}: {}'.format(k, v)) + if len(self) <= 0: + logger.error(' Empty DiscoveryConfig Found! ') + logger.info('---------------------------------------------') + + @classmethod + def Create(cls, **kwargs): + o = cls() + + for k, v in os.environ.items(): + if not k.startswith(cls.CONFIG_PREFIX): + continue + o[k] = v + for k, v in kwargs.items(): + o[k] = v + + o.dump() + return o diff --git a/shards/discovery/factory.py b/shards/discovery/factory.py new file mode 100644 index 0000000000..5f5c7fcf95 --- /dev/null +++ b/shards/discovery/factory.py @@ -0,0 +1,22 @@ +import logging +from discovery import DiscoveryConfig +from utils.plugins import BaseMixin + +logger = logging.getLogger(__name__) +PLUGIN_PACKAGE_NAME = 'discovery.plugins' + + +class DiscoveryFactory(BaseMixin): + PLUGIN_TYPE = 'Discovery' + + def __init__(self, searchpath=None): + super().__init__(searchpath=searchpath, package_name=PLUGIN_PACKAGE_NAME) + + def _create(self, plugin_class, **kwargs): + conn_mgr = kwargs.pop('conn_mgr', None) + if not conn_mgr: + raise RuntimeError('Please pass conn_mgr to create discovery!') + + plugin_config = DiscoveryConfig.Create() + plugin = plugin_class.Create(plugin_config=plugin_config, conn_mgr=conn_mgr, **kwargs) + return plugin diff --git a/shards/discovery/plugins/kubernetes_provider.py b/shards/discovery/plugins/kubernetes_provider.py new file mode 100644 index 0000000000..aaf6091f83 --- /dev/null +++ b/shards/discovery/plugins/kubernetes_provider.py @@ -0,0 +1,346 @@ +import os +import sys +if __name__ == '__main__': + sys.path.append(os.path.dirname(os.path.dirname( + os.path.abspath(__file__)))) + +import re +import logging +import time +import copy +import threading +import queue +import enum +from kubernetes import client, config, watch + +logger = logging.getLogger(__name__) + +INCLUSTER_NAMESPACE_PATH = '/var/run/secrets/kubernetes.io/serviceaccount/namespace' + + +class EventType(enum.Enum): + PodHeartBeat = 1 + Watch = 2 + + +class K8SMixin: + def __init__(self, namespace, in_cluster=False, **kwargs): + self.namespace = namespace + self.in_cluster = in_cluster + self.kwargs = kwargs + self.v1 = kwargs.get('v1', None) + if not self.namespace: + self.namespace = open(INCLUSTER_NAMESPACE_PATH).read() + + if not self.v1: + config.load_incluster_config( + ) if self.in_cluster else config.load_kube_config() + self.v1 = client.CoreV1Api() + + +class K8SHeartbeatHandler(threading.Thread, K8SMixin): + name = 'kubernetes' + + def __init__(self, + message_queue, + namespace, + label_selector, + in_cluster=False, + **kwargs): + K8SMixin.__init__(self, + namespace=namespace, + in_cluster=in_cluster, + **kwargs) + threading.Thread.__init__(self) + self.queue = message_queue + self.terminate = False + self.label_selector = label_selector + self.poll_interval = kwargs.get('poll_interval', 5) + + def run(self): + while not self.terminate: + try: + pods = self.v1.list_namespaced_pod( + namespace=self.namespace, + label_selector=self.label_selector) + event_message = {'eType': EventType.PodHeartBeat, 'events': []} + for item in pods.items: + pod = self.v1.read_namespaced_pod(name=item.metadata.name, + namespace=self.namespace) + name = pod.metadata.name + ip = pod.status.pod_ip + phase = pod.status.phase + reason = pod.status.reason + message = pod.status.message + ready = True if phase == 'Running' else False + + pod_event = dict(pod=name, + ip=ip, + ready=ready, + reason=reason, + message=message) + + event_message['events'].append(pod_event) + + self.queue.put(event_message) + + except Exception as exc: + logger.error(exc) + + time.sleep(self.poll_interval) + + def stop(self): + self.terminate = True + + +class K8SEventListener(threading.Thread, K8SMixin): + def __init__(self, message_queue, namespace, in_cluster=False, **kwargs): + K8SMixin.__init__(self, + namespace=namespace, + in_cluster=in_cluster, + **kwargs) + threading.Thread.__init__(self) + self.queue = message_queue + self.terminate = False + self.at_start_up = True + self._stop_event = threading.Event() + + def stop(self): + self.terminate = True + self._stop_event.set() + + def run(self): + resource_version = '' + w = watch.Watch() + for event in w.stream(self.v1.list_namespaced_event, + namespace=self.namespace, + field_selector='involvedObject.kind=Pod'): + if self.terminate: + break + + resource_version = int(event['object'].metadata.resource_version) + + info = dict( + eType=EventType.Watch, + pod=event['object'].involved_object.name, + reason=event['object'].reason, + message=event['object'].message, + start_up=self.at_start_up, + ) + self.at_start_up = False + # logger.info('Received event: {}'.format(info)) + self.queue.put(info) + + +class EventHandler(threading.Thread): + def __init__(self, mgr, message_queue, namespace, pod_patt, **kwargs): + threading.Thread.__init__(self) + self.mgr = mgr + self.queue = message_queue + self.kwargs = kwargs + self.terminate = False + self.pod_patt = re.compile(pod_patt) + self.namespace = namespace + + def stop(self): + self.terminate = True + + def on_drop(self, event, **kwargs): + pass + + def on_pod_started(self, event, **kwargs): + try_cnt = 3 + pod = None + while try_cnt > 0: + try_cnt -= 1 + try: + pod = self.mgr.v1.read_namespaced_pod(name=event['pod'], + namespace=self.namespace) + if not pod.status.pod_ip: + time.sleep(0.5) + continue + break + except client.rest.ApiException as exc: + time.sleep(0.5) + + if try_cnt <= 0 and not pod: + if not event['start_up']: + logger.error('Pod {} is started but cannot read pod'.format( + event['pod'])) + return + elif try_cnt <= 0 and not pod.status.pod_ip: + logger.warning('NoPodIPFoundError') + return + + logger.info('Register POD {} with IP {}'.format( + pod.metadata.name, pod.status.pod_ip)) + self.mgr.add_pod(name=pod.metadata.name, ip=pod.status.pod_ip) + + def on_pod_killing(self, event, **kwargs): + logger.info('Unregister POD {}'.format(event['pod'])) + self.mgr.delete_pod(name=event['pod']) + + def on_pod_heartbeat(self, event, **kwargs): + names = self.mgr.conn_mgr.conn_names + + running_names = set() + for each_event in event['events']: + if each_event['ready']: + self.mgr.add_pod(name=each_event['pod'], ip=each_event['ip']) + running_names.add(each_event['pod']) + else: + self.mgr.delete_pod(name=each_event['pod']) + + to_delete = names - running_names + for name in to_delete: + self.mgr.delete_pod(name) + + logger.info(self.mgr.conn_mgr.conn_names) + + def handle_event(self, event): + if event['eType'] == EventType.PodHeartBeat: + return self.on_pod_heartbeat(event) + + if not event or (event['reason'] not in ('Started', 'Killing')): + return self.on_drop(event) + + if not re.match(self.pod_patt, event['pod']): + return self.on_drop(event) + + logger.info('Handling event: {}'.format(event)) + + if event['reason'] == 'Started': + return self.on_pod_started(event) + + return self.on_pod_killing(event) + + def run(self): + while not self.terminate: + try: + event = self.queue.get(timeout=1) + self.handle_event(event) + except queue.Empty: + continue + + +class KubernetesProviderSettings: + def __init__(self, namespace, pod_patt, label_selector, in_cluster, + poll_interval, port=None, **kwargs): + self.namespace = namespace + self.pod_patt = pod_patt + self.label_selector = label_selector + self.in_cluster = in_cluster + self.poll_interval = poll_interval + self.port = int(port) if port else 19530 + + +class KubernetesProvider(object): + name = 'kubernetes' + + def __init__(self, plugin_config, conn_mgr, **kwargs): + self.namespace = plugin_config.DISCOVERY_KUBERNETES_NAMESPACE + self.pod_patt = plugin_config.DISCOVERY_KUBERNETES_POD_PATT + self.label_selector = plugin_config.DISCOVERY_KUBERNETES_LABEL_SELECTOR + self.in_cluster = plugin_config.DISCOVERY_KUBERNETES_IN_CLUSTER.lower() + self.in_cluster = self.in_cluster == 'true' + self.poll_interval = plugin_config.DISCOVERY_KUBERNETES_POLL_INTERVAL + self.poll_interval = int(self.poll_interval) if self.poll_interval else 5 + self.port = plugin_config.DISCOVERY_KUBERNETES_PORT + self.port = int(self.port) if self.port else 19530 + self.kwargs = kwargs + self.queue = queue.Queue() + + self.conn_mgr = conn_mgr + + if not self.namespace: + self.namespace = open(incluster_namespace_path).read() + + config.load_incluster_config( + ) if self.in_cluster else config.load_kube_config() + self.v1 = client.CoreV1Api() + + self.listener = K8SEventListener(message_queue=self.queue, + namespace=self.namespace, + in_cluster=self.in_cluster, + v1=self.v1, + **kwargs) + + self.pod_heartbeater = K8SHeartbeatHandler( + message_queue=self.queue, + namespace=self.namespace, + label_selector=self.label_selector, + in_cluster=self.in_cluster, + v1=self.v1, + poll_interval=self.poll_interval, + **kwargs) + + self.event_handler = EventHandler(mgr=self, + message_queue=self.queue, + namespace=self.namespace, + pod_patt=self.pod_patt, + **kwargs) + + def add_pod(self, name, ip): + self.conn_mgr.register(name, 'tcp://{}:{}'.format(ip, self.port)) + + def delete_pod(self, name): + self.conn_mgr.unregister(name) + + def start(self): + self.listener.daemon = True + self.listener.start() + self.event_handler.start() + + self.pod_heartbeater.start() + + def stop(self): + self.listener.stop() + self.pod_heartbeater.stop() + self.event_handler.stop() + + @classmethod + def Create(cls, conn_mgr, plugin_config, **kwargs): + discovery = cls(plugin_config=plugin_config, conn_mgr=conn_mgr, **kwargs) + return discovery + + +def setup(app): + logger.info('Plugin \'{}\' Installed In Package: {}'.format(__file__, app.plugin_package_name)) + app.on_plugin_setup(KubernetesProvider) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname( + os.path.abspath(__file__)))))) + sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname( + os.path.abspath(__file__))))) + + class Connect: + def register(self, name, value): + logger.error('Register: {} - {}'.format(name, value)) + + def unregister(self, name): + logger.error('Unregister: {}'.format(name)) + + @property + def conn_names(self): + return set() + + connect_mgr = Connect() + + from discovery import DiscoveryConfig + settings = DiscoveryConfig(DISCOVERY_KUBERNETES_NAMESPACE='xp', + DISCOVERY_KUBERNETES_POD_PATT=".*-ro-servers-.*", + DISCOVERY_KUBERNETES_LABEL_SELECTOR='tier=ro-servers', + DISCOVERY_KUBERNETES_POLL_INTERVAL=5, + DISCOVERY_KUBERNETES_IN_CLUSTER=False) + + provider_class = KubernetesProvider + t = provider_class(conn_mgr=connect_mgr, plugin_config=settings) + t.start() + cnt = 100 + while cnt > 0: + time.sleep(2) + cnt -= 1 + t.stop() diff --git a/shards/discovery/plugins/static_provider.py b/shards/discovery/plugins/static_provider.py new file mode 100644 index 0000000000..fca8c717db --- /dev/null +++ b/shards/discovery/plugins/static_provider.py @@ -0,0 +1,45 @@ +import os +import sys +if __name__ == '__main__': + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import logging +import socket +from environs import Env + +logger = logging.getLogger(__name__) +env = Env() + + +class StaticDiscovery(object): + name = 'static' + + def __init__(self, config, conn_mgr, **kwargs): + self.conn_mgr = conn_mgr + hosts = env.list('DISCOVERY_STATIC_HOSTS', []) + self.port = env.int('DISCOVERY_STATIC_PORT', 19530) + self.hosts = [socket.gethostbyname(host) for host in hosts] + + def start(self): + for host in self.hosts: + self.add_pod(host, host) + + def stop(self): + for host in self.hosts: + self.delete_pod(host) + + def add_pod(self, name, ip): + self.conn_mgr.register(name, 'tcp://{}:{}'.format(ip, self.port)) + + def delete_pod(self, name): + self.conn_mgr.unregister(name) + + @classmethod + def Create(cls, conn_mgr, plugin_config, **kwargs): + discovery = cls(config=plugin_config, conn_mgr=conn_mgr, **kwargs) + return discovery + + +def setup(app): + logger.info('Plugin \'{}\' Installed In Package: {}'.format(__file__, app.plugin_package_name)) + app.on_plugin_setup(StaticDiscovery) diff --git a/shards/kubernetes_demo/milvus_auxiliary.yaml b/shards/kubernetes_demo/milvus_auxiliary.yaml new file mode 100644 index 0000000000..fff27adc6f --- /dev/null +++ b/shards/kubernetes_demo/milvus_auxiliary.yaml @@ -0,0 +1,67 @@ +kind: Service +apiVersion: v1 +metadata: + name: milvus-mysql + namespace: milvus +spec: + type: ClusterIP + selector: + app: milvus + tier: mysql + ports: + - protocol: TCP + port: 3306 + targetPort: 3306 + name: mysql + +--- + +apiVersion: apps/v1 +kind: Deployment +metadata: + name: milvus-mysql + namespace: milvus +spec: + selector: + matchLabels: + app: milvus + tier: mysql + replicas: 1 + template: + metadata: + labels: + app: milvus + tier: mysql + spec: + containers: + - name: milvus-mysql + image: mysql:5.7 + imagePullPolicy: IfNotPresent + # lifecycle: + # postStart: + # exec: + # command: ["/bin/sh", "-c", "mysql -h milvus-mysql -uroot -p${MYSQL_ROOT_PASSWORD} -e \"CREATE DATABASE IF NOT EXISTS ${DATABASE};\"; \ + # mysql -uroot -p${MYSQL_ROOT_PASSWORD} -e \"GRANT ALL PRIVILEGES ON ${DATABASE}.* TO 'root'@'%';\""] + env: + - name: MYSQL_ROOT_PASSWORD + value: milvusroot + - name: DATABASE + value: milvus + ports: + - name: mysql-port + containerPort: 3306 + volumeMounts: + - name: milvus-mysql-disk + mountPath: /data + subPath: mysql + - name: milvus-mysql-configmap + mountPath: /etc/mysql/mysql.conf.d/mysqld.cnf + subPath: milvus_mysql_config.yml + + volumes: + - name: milvus-mysql-disk + persistentVolumeClaim: + claimName: milvus-mysql-disk + - name: milvus-mysql-configmap + configMap: + name: milvus-mysql-configmap diff --git a/shards/kubernetes_demo/milvus_configmap.yaml b/shards/kubernetes_demo/milvus_configmap.yaml new file mode 100644 index 0000000000..cb751c02f1 --- /dev/null +++ b/shards/kubernetes_demo/milvus_configmap.yaml @@ -0,0 +1,185 @@ +apiVersion: v1 +kind: ConfigMap +metadata: + name: milvus-mysql-configmap + namespace: milvus +data: + milvus_mysql_config.yml: | + [mysqld] + pid-file = /var/run/mysqld/mysqld.pid + socket = /var/run/mysqld/mysqld.sock + datadir = /data + log-error = /var/log/mysql/error.log # mount out to host + # By default we only accept connections from localhost + bind-address = 0.0.0.0 + # Disabling symbolic-links is recommended to prevent assorted security risks + symbolic-links=0 + character-set-server = utf8mb4 + collation-server = utf8mb4_unicode_ci + init_connect='SET NAMES utf8mb4' + skip-character-set-client-handshake = true + max_connections = 1000 + wait_timeout = 31536000 + +--- + +apiVersion: v1 +kind: ConfigMap +metadata: + name: milvus-proxy-configmap + namespace: milvus +data: + milvus_proxy_config.yml: | + DEBUG=True + TESTING=False + + WOSERVER=tcp://milvus-wo-servers:19530 + SERVER_PORT=19530 + + DISCOVERY_CLASS_NAME=kubernetes + DISCOVERY_KUBERNETES_NAMESPACE=milvus + DISCOVERY_KUBERNETES_POD_PATT=.*-ro-servers-.* + DISCOVERY_KUBERNETES_LABEL_SELECTOR=tier=ro-servers + DISCOVERY_KUBERNETES_POLL_INTERVAL=10 + DISCOVERY_KUBERNETES_IN_CLUSTER=True + + SQLALCHEMY_DATABASE_URI=mysql+pymysql://root:milvusroot@milvus-mysql:3306/milvus?charset=utf8mb4 + SQLALCHEMY_POOL_SIZE=50 + SQLALCHEMY_POOL_RECYCLE=7200 + + LOG_PATH=/var/log/milvus + TIMEZONE=Asia/Shanghai +--- + +apiVersion: v1 +kind: ConfigMap +metadata: + name: milvus-roserver-configmap + namespace: milvus +data: + config.yml: | + server_config: + address: 0.0.0.0 + port: 19530 + mode: cluster_readonly + + db_config: + primary_path: /var/milvus + backend_url: mysql://root:milvusroot@milvus-mysql:3306/milvus + insert_buffer_size: 2 + + metric_config: + enable_monitor: off # true is on, false is off + + cache_config: + cpu_cache_capacity: 12 # memory pool to hold index data, unit: GB + cpu_cache_free_percent: 0.85 + insert_cache_immediately: false + # gpu_cache_capacity: 4 + # gpu_cache_free_percent: 0.85 + # gpu_ids: + # - 0 + + engine_config: + use_blas_threshold: 800 + + resource_config: + search_resources: + - gpu0 + + log.conf: | + * GLOBAL: + FORMAT = "%datetime | %level | %logger | %msg" + FILENAME = "/var/milvus/logs/milvus-ro-%datetime{%H:%m}-global.log" + ENABLED = true + TO_FILE = true + TO_STANDARD_OUTPUT = true + SUBSECOND_PRECISION = 3 + PERFORMANCE_TRACKING = false + MAX_LOG_FILE_SIZE = 2097152 ## Throw log files away after 2MB + * DEBUG: + FILENAME = "/var/milvus/logs/milvus-ro-%datetime{%H:%m}-debug.log" + ENABLED = true + * WARNING: + FILENAME = "/var/milvus/logs/milvus-ro-%datetime{%H:%m}-warning.log" + * TRACE: + FILENAME = "/var/milvus/logs/milvus-ro-%datetime{%H:%m}-trace.log" + * VERBOSE: + FORMAT = "%datetime{%d/%M/%y} | %level-%vlevel | %msg" + TO_FILE = true + TO_STANDARD_OUTPUT = true + ## Error logs + * ERROR: + ENABLED = true + FILENAME = "/var/milvus/logs/milvus-ro-%datetime{%H:%m}-error.log" + * FATAL: + ENABLED = true + FILENAME = "/var/milvus/logs/milvus-ro-%datetime{%H:%m}-fatal.log" + +--- + +apiVersion: v1 +kind: ConfigMap +metadata: + name: milvus-woserver-configmap + namespace: milvus +data: + config.yml: | + server_config: + address: 0.0.0.0 + port: 19530 + mode: cluster_writable + + db_config: + primary_path: /var/milvus + backend_url: mysql://root:milvusroot@milvus-mysql:3306/milvus + insert_buffer_size: 2 + + metric_config: + enable_monitor: off # true is on, false is off + + cache_config: + cpu_cache_capacity: 2 # memory pool to hold index data, unit: GB + cpu_cache_free_percent: 0.85 + insert_cache_immediately: false + # gpu_cache_capacity: 4 + # gpu_cache_free_percent: 0.85 + # gpu_ids: + # - 0 + + engine_config: + use_blas_threshold: 800 + + resource_config: + search_resources: + - gpu0 + + + log.conf: | + * GLOBAL: + FORMAT = "%datetime | %level | %logger | %msg" + FILENAME = "/var/milvus/logs/milvus-wo-%datetime{%H:%m}-global.log" + ENABLED = true + TO_FILE = true + TO_STANDARD_OUTPUT = true + SUBSECOND_PRECISION = 3 + PERFORMANCE_TRACKING = false + MAX_LOG_FILE_SIZE = 2097152 ## Throw log files away after 2MB + * DEBUG: + FILENAME = "/var/milvus/logs/milvus-wo-%datetime{%H:%m}-debug.log" + ENABLED = true + * WARNING: + FILENAME = "/var/milvus/logs/milvus-wo-%datetime{%H:%m}-warning.log" + * TRACE: + FILENAME = "/var/milvus/logs/milvus-wo-%datetime{%H:%m}-trace.log" + * VERBOSE: + FORMAT = "%datetime{%d/%M/%y} | %level-%vlevel | %msg" + TO_FILE = true + TO_STANDARD_OUTPUT = true + ## Error logs + * ERROR: + ENABLED = true + FILENAME = "/var/milvus/logs/milvus-wo-%datetime{%H:%m}-error.log" + * FATAL: + ENABLED = true + FILENAME = "/var/milvus/logs/milvus-wo-%datetime{%H:%m}-fatal.log" diff --git a/shards/kubernetes_demo/milvus_data_pvc.yaml b/shards/kubernetes_demo/milvus_data_pvc.yaml new file mode 100644 index 0000000000..480354507d --- /dev/null +++ b/shards/kubernetes_demo/milvus_data_pvc.yaml @@ -0,0 +1,57 @@ +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: milvus-db-disk + namespace: milvus +spec: + accessModes: + - ReadWriteMany + storageClassName: default + resources: + requests: + storage: 50Gi + +--- + +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: milvus-log-disk + namespace: milvus +spec: + accessModes: + - ReadWriteMany + storageClassName: default + resources: + requests: + storage: 50Gi + +--- + +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: milvus-mysql-disk + namespace: milvus +spec: + accessModes: + - ReadWriteMany + storageClassName: default + resources: + requests: + storage: 50Gi + +--- + +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: milvus-redis-disk + namespace: milvus +spec: + accessModes: + - ReadWriteOnce + storageClassName: default + resources: + requests: + storage: 5Gi diff --git a/shards/kubernetes_demo/milvus_proxy.yaml b/shards/kubernetes_demo/milvus_proxy.yaml new file mode 100644 index 0000000000..13916b7b2b --- /dev/null +++ b/shards/kubernetes_demo/milvus_proxy.yaml @@ -0,0 +1,88 @@ +kind: Service +apiVersion: v1 +metadata: + name: milvus-proxy-servers + namespace: milvus +spec: + type: LoadBalancer + selector: + app: milvus + tier: proxy + ports: + - name: tcp + protocol: TCP + port: 19530 + targetPort: 19530 + +--- + +apiVersion: apps/v1 +kind: Deployment +metadata: + name: milvus-proxy + namespace: milvus +spec: + selector: + matchLabels: + app: milvus + tier: proxy + replicas: 1 + template: + metadata: + labels: + app: milvus + tier: proxy + spec: + containers: + - name: milvus-proxy + image: milvusdb/mishards:0.1.0-rc0 + imagePullPolicy: Always + command: ["python", "mishards/main.py"] + resources: + limits: + memory: "3Gi" + cpu: "4" + requests: + memory: "2Gi" + ports: + - name: tcp + containerPort: 5000 + env: + # - name: SQL_ECHO + # value: "True" + - name: DEBUG + value: "False" + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: MILVUS_CLIENT + value: "False" + - name: LOG_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: LOG_PATH + value: /var/log/milvus + - name: SD_NAMESPACE + valueFrom: + fieldRef: + fieldPath: metadata.namespace + - name: SD_ROSERVER_POD_PATT + value: ".*-ro-servers-.*" + volumeMounts: + - name: milvus-proxy-configmap + mountPath: /source/mishards/.env + subPath: milvus_proxy_config.yml + - name: milvus-log-disk + mountPath: /var/log/milvus + subPath: proxylog + # imagePullSecrets: + # - name: regcred + volumes: + - name: milvus-proxy-configmap + configMap: + name: milvus-proxy-configmap + - name: milvus-log-disk + persistentVolumeClaim: + claimName: milvus-log-disk diff --git a/shards/kubernetes_demo/milvus_rbac.yaml b/shards/kubernetes_demo/milvus_rbac.yaml new file mode 100644 index 0000000000..e6f302be15 --- /dev/null +++ b/shards/kubernetes_demo/milvus_rbac.yaml @@ -0,0 +1,24 @@ +kind: ClusterRole +apiVersion: rbac.authorization.k8s.io/v1 +metadata: + name: pods-list +rules: +- apiGroups: [""] + resources: ["pods", "events"] + verbs: ["list", "get", "watch"] + +--- + +kind: ClusterRoleBinding +apiVersion: rbac.authorization.k8s.io/v1 +metadata: + name: pods-list +subjects: +- kind: ServiceAccount + name: default + namespace: milvus +roleRef: + kind: ClusterRole + name: pods-list + apiGroup: rbac.authorization.k8s.io +--- diff --git a/shards/kubernetes_demo/milvus_stateful_servers.yaml b/shards/kubernetes_demo/milvus_stateful_servers.yaml new file mode 100644 index 0000000000..4ff5045599 --- /dev/null +++ b/shards/kubernetes_demo/milvus_stateful_servers.yaml @@ -0,0 +1,68 @@ +kind: Service +apiVersion: v1 +metadata: + name: milvus-ro-servers + namespace: milvus +spec: + type: ClusterIP + selector: + app: milvus + tier: ro-servers + ports: + - protocol: TCP + port: 19530 + targetPort: 19530 + +--- + +apiVersion: apps/v1beta1 +kind: StatefulSet +metadata: + name: milvus-ro-servers + namespace: milvus +spec: + serviceName: "milvus-ro-servers" + replicas: 1 + template: + metadata: + labels: + app: milvus + tier: ro-servers + spec: + terminationGracePeriodSeconds: 11 + containers: + - name: milvus-ro-server + image: milvusdb/milvus:0.5.0-d102119-ede20b + imagePullPolicy: Always + ports: + - containerPort: 19530 + resources: + limits: + memory: "16Gi" + cpu: "8.0" + requests: + memory: "14Gi" + volumeMounts: + - name: milvus-db-disk + mountPath: /var/milvus + subPath: dbdata + - name: milvus-roserver-configmap + mountPath: /opt/milvus/conf/server_config.yaml + subPath: config.yml + - name: milvus-roserver-configmap + mountPath: /opt/milvus/conf/log_config.conf + subPath: log.conf + # imagePullSecrets: + # - name: regcred + # tolerations: + # - key: "worker" + # operator: "Equal" + # value: "performance" + # effect: "NoSchedule" + volumes: + - name: milvus-roserver-configmap + configMap: + name: milvus-roserver-configmap + - name: milvus-db-disk + persistentVolumeClaim: + claimName: milvus-db-disk diff --git a/shards/kubernetes_demo/milvus_write_servers.yaml b/shards/kubernetes_demo/milvus_write_servers.yaml new file mode 100644 index 0000000000..6aec4b0373 --- /dev/null +++ b/shards/kubernetes_demo/milvus_write_servers.yaml @@ -0,0 +1,70 @@ +kind: Service +apiVersion: v1 +metadata: + name: milvus-wo-servers + namespace: milvus +spec: + type: ClusterIP + selector: + app: milvus + tier: wo-servers + ports: + - protocol: TCP + port: 19530 + targetPort: 19530 + +--- + +apiVersion: apps/v1beta1 +kind: Deployment +metadata: + name: milvus-wo-servers + namespace: milvus +spec: + selector: + matchLabels: + app: milvus + tier: wo-servers + replicas: 1 + template: + metadata: + labels: + app: milvus + tier: wo-servers + spec: + containers: + - name: milvus-wo-server + image: milvusdb/milvus:0.5.0-d102119-ede20b + imagePullPolicy: Always + ports: + - containerPort: 19530 + resources: + limits: + memory: "5Gi" + cpu: "1.0" + requests: + memory: "4Gi" + volumeMounts: + - name: milvus-db-disk + mountPath: /var/milvus + subPath: dbdata + - name: milvus-woserver-configmap + mountPath: /opt/milvus/conf/server_config.yaml + subPath: config.yml + - name: milvus-woserver-configmap + mountPath: /opt/milvus/conf/log_config.conf + subPath: log.conf + # imagePullSecrets: + # - name: regcred + # tolerations: + # - key: "worker" + # operator: "Equal" + # value: "performance" + # effect: "NoSchedule" + volumes: + - name: milvus-woserver-configmap + configMap: + name: milvus-woserver-configmap + - name: milvus-db-disk + persistentVolumeClaim: + claimName: milvus-db-disk diff --git a/shards/kubernetes_demo/start.sh b/shards/kubernetes_demo/start.sh new file mode 100755 index 0000000000..7441aa5d70 --- /dev/null +++ b/shards/kubernetes_demo/start.sh @@ -0,0 +1,368 @@ +#!/bin/bash + +UL=`tput smul` +NOUL=`tput rmul` +BOLD=`tput bold` +NORMAL=`tput sgr0` +RED='\033[0;31m' +GREEN='\033[0;32m' +BLUE='\033[0;34m' +YELLOW='\033[1;33m' +ENDC='\033[0m' + +function showHelpMessage () { + echo -e "${BOLD}Usage:${NORMAL} ${RED}$0${ENDC} [option...] {cleanup${GREEN}|${ENDC}baseup${GREEN}|${ENDC}appup${GREEN}|${ENDC}appdown${GREEN}|${ENDC}allup}" >&2 + echo + echo " -h, --help show help message" + echo " ${BOLD}cleanup, delete all resources${NORMAL}" + echo " ${BOLD}baseup, start all required base resources${NORMAL}" + echo " ${BOLD}appup, start all pods${NORMAL}" + echo " ${BOLD}appdown, remove all pods${NORMAL}" + echo " ${BOLD}allup, start all base resources and pods${NORMAL}" + echo " ${BOLD}scale-proxy, scale proxy${NORMAL}" + echo " ${BOLD}scale-ro-server, scale readonly servers${NORMAL}" + echo " ${BOLD}scale-worker, scale calculation workers${NORMAL}" +} + +function showscaleHelpMessage () { + echo -e "${BOLD}Usage:${NORMAL} ${RED}$0 $1${ENDC} [option...] {1|2|3|4|...}" >&2 + echo + echo " -h, --help show help message" + echo " ${BOLD}number, (int) target scale number" +} + +function PrintScaleSuccessMessage() { + echo -e "${BLUE}${BOLD}Successfully Scaled: ${1} --> ${2}${ENDC}" +} + +function PrintPodStatusMessage() { + echo -e "${BOLD}${1}${NORMAL}" +} + +timeout=60 + +function setUpMysql () { + mysqlUserName=$(kubectl describe configmap -n milvus milvus-roserver-configmap | + grep backend_url | + awk '{print $2}' | + awk '{split($0, level1, ":"); + split(level1[2], level2, "/"); + print level2[3]}') + mysqlPassword=$(kubectl describe configmap -n milvus milvus-roserver-configmap | + grep backend_url | + awk '{print $2}' | + awk '{split($0, level1, ":"); + split(level1[3], level3, "@"); + print level3[1]}') + mysqlDBName=$(kubectl describe configmap -n milvus milvus-roserver-configmap | + grep backend_url | + awk '{print $2}' | + awk '{split($0, level1, ":"); + split(level1[4], level4, "/"); + print level4[2]}') + mysqlContainer=$(kubectl get pods -n milvus | grep milvus-mysql | awk '{print $1}') + + kubectl exec -n milvus $mysqlContainer -- mysql -h milvus-mysql -u$mysqlUserName -p$mysqlPassword -e "CREATE DATABASE IF NOT EXISTS $mysqlDBName;" + + checkDBExists=$(kubectl exec -n milvus $mysqlContainer -- mysql -h milvus-mysql -u$mysqlUserName -p$mysqlPassword -e "SELECT schema_name FROM information_schema.schemata WHERE schema_name = '$mysqlDBName';" | grep -o $mysqlDBName | wc -l) + counter=0 + while [ $checkDBExists -lt 1 ]; do + sleep 1 + let counter=counter+1 + if [ $counter == $timeout ]; then + echo "Creating MySQL database $mysqlDBName timeout" + return 1 + fi + checkDBExists=$(kubectl exec -n milvus $mysqlContainer -- mysql -h milvus-mysql -u$mysqlUserName -p$mysqlPassword -e "SELECT schema_name FROM information_schema.schemata WHERE schema_name = '$mysqlDBName';" | grep -o $mysqlDBName | wc -l) + done; + + kubectl exec -n milvus $mysqlContainer -- mysql -h milvus-mysql -u$mysqlUserName -p$mysqlPassword -e "GRANT ALL PRIVILEGES ON $mysqlDBName.* TO '$mysqlUserName'@'%';" + kubectl exec -n milvus $mysqlContainer -- mysql -h milvus-mysql -u$mysqlUserName -p$mysqlPassword -e "FLUSH PRIVILEGES;" + checkGrant=$(kubectl exec -n milvus $mysqlContainer -- mysql -h milvus-mysql -u$mysqlUserName -p$mysqlPassword -e "SHOW GRANTS for $mysqlUserName;" | grep -o "GRANT ALL PRIVILEGES ON \`$mysqlDBName\`\.\*" | wc -l) + counter=0 + while [ $checkGrant -lt 1 ]; do + sleep 1 + let counter=counter+1 + if [ $counter == $timeout ]; then + echo "Granting all privileges on $mysqlDBName to $mysqlUserName timeout" + return 1 + fi + checkGrant=$(kubectl exec -n milvus $mysqlContainer -- mysql -h milvus-mysql -u$mysqlUserName -p$mysqlPassword -e "SHOW GRANTS for $mysqlUserName;" | grep -o "GRANT ALL PRIVILEGES ON \`$mysqlDBName\`\.\*" | wc -l) + done; +} + +function checkStatefulSevers() { + stateful_replicas=$(kubectl describe statefulset -n milvus milvus-ro-servers | grep "Replicas:" | awk '{print $2}') + stateful_running_pods=$(kubectl describe statefulset -n milvus milvus-ro-servers | grep "Pods Status:" | awk '{print $3}') + + counter=0 + prev=$stateful_running_pods + PrintPodStatusMessage "Running milvus-ro-servers Pods: $stateful_running_pods/$stateful_replicas" + while [ $stateful_replicas != $stateful_running_pods ]; do + echo -e "${YELLOW}Wait another 1 sec --- ${counter}${ENDC}" + sleep 1; + + let counter=counter+1 + if [ $counter -eq $timeout ]; then + return 1; + fi + + stateful_running_pods=$(kubectl describe statefulset -n milvus milvus-ro-servers | grep "Pods Status:" | awk '{print $3}') + if [ $stateful_running_pods -ne $prev ]; then + PrintPodStatusMessage "Running milvus-ro-servers Pods: $stateful_running_pods/$stateful_replicas" + fi + prev=$stateful_running_pods + done; + return 0; +} + +function checkDeployment() { + deployment_name=$1 + replicas=$(kubectl describe deployment -n milvus $deployment_name | grep "Replicas:" | awk '{print $2}') + running=$(kubectl get pods -n milvus | grep $deployment_name | grep Running | wc -l) + + counter=0 + prev=$running + PrintPodStatusMessage "Running $deployment_name Pods: $running/$replicas" + while [ $replicas != $running ]; do + echo -e "${YELLOW}Wait another 1 sec --- ${counter}${ENDC}" + sleep 1; + + let counter=counter+1 + if [ $counter == $timeout ]; then + return 1 + fi + + running=$(kubectl get pods -n milvus | grep "$deployment_name" | grep Running | wc -l) + if [ $running -ne $prev ]; then + PrintPodStatusMessage "Running $deployment_name Pods: $running/$replicas" + fi + prev=$running + done +} + + +function startDependencies() { + kubectl apply -f milvus_data_pvc.yaml + kubectl apply -f milvus_configmap.yaml + kubectl apply -f milvus_auxiliary.yaml + + counter=0 + while [ $(kubectl get pvc -n milvus | grep Bound | wc -l) != 4 ]; do + sleep 1; + let counter=counter+1 + if [ $counter == $timeout ]; then + echo "baseup timeout" + return 1 + fi + done + checkDeployment "milvus-mysql" +} + +function startApps() { + counter=0 + errmsg="" + echo -e "${GREEN}${BOLD}Checking required resouces...${NORMAL}${ENDC}" + while [ $counter -lt $timeout ]; do + sleep 1; + if [ $(kubectl get pvc -n milvus 2>/dev/null | grep Bound | wc -l) != 4 ]; then + echo -e "${YELLOW}No pvc. Wait another sec... $counter${ENDC}"; + errmsg='No pvc'; + let counter=counter+1; + continue + fi + if [ $(kubectl get configmap -n milvus 2>/dev/null | grep milvus | wc -l) != 4 ]; then + echo -e "${YELLOW}No configmap. Wait another sec... $counter${ENDC}"; + errmsg='No configmap'; + let counter=counter+1; + continue + fi + if [ $(kubectl get ep -n milvus 2>/dev/null | grep milvus-mysql | awk '{print $2}') == "" ]; then + echo -e "${YELLOW}No mysql. Wait another sec... $counter${ENDC}"; + errmsg='No mysql'; + let counter=counter+1; + continue + fi + # if [ $(kubectl get ep -n milvus 2>/dev/null | grep milvus-redis | awk '{print $2}') == "" ]; then + # echo -e "${NORMAL}${YELLOW}No redis. Wait another sec... $counter${ENDC}"; + # errmsg='No redis'; + # let counter=counter+1; + # continue + # fi + break; + done + + if [ $counter -ge $timeout ]; then + echo -e "${RED}${BOLD}Start APP Error: $errmsg${NORMAL}${ENDC}" + exit 1; + fi + + echo -e "${GREEN}${BOLD}Setup requried database ...${NORMAL}${ENDC}" + setUpMysql + if [ $? -ne 0 ]; then + echo -e "${RED}${BOLD}Setup MySQL database timeout${NORMAL}${ENDC}" + exit 1 + fi + + echo -e "${GREEN}${BOLD}Start servers ...${NORMAL}${ENDC}" + kubectl apply -f milvus_stateful_servers.yaml + kubectl apply -f milvus_write_servers.yaml + + checkStatefulSevers + if [ $? -ne 0 ]; then + echo -e "${RED}${BOLD}Starting milvus-ro-servers timeout${NORMAL}${ENDC}" + exit 1 + fi + + checkDeployment "milvus-wo-servers" + if [ $? -ne 0 ]; then + echo -e "${RED}${BOLD}Starting milvus-wo-servers timeout${NORMAL}${ENDC}" + exit 1 + fi + + echo -e "${GREEN}${BOLD}Start rolebinding ...${NORMAL}${ENDC}" + kubectl apply -f milvus_rbac.yaml + + echo -e "${GREEN}${BOLD}Start proxies ...${NORMAL}${ENDC}" + kubectl apply -f milvus_proxy.yaml + + checkDeployment "milvus-proxy" + if [ $? -ne 0 ]; then + echo -e "${RED}${BOLD}Starting milvus-proxy timeout${NORMAL}${ENDC}" + exit 1 + fi + + # echo -e "${GREEN}${BOLD}Start flower ...${NORMAL}${ENDC}" + # kubectl apply -f milvus_flower.yaml + # checkDeployment "milvus-flower" + # if [ $? -ne 0 ]; then + # echo -e "${RED}${BOLD}Starting milvus-flower timeout${NORMAL}${ENDC}" + # exit 1 + # fi + +} + +function removeApps () { + # kubectl delete -f milvus_flower.yaml 2>/dev/null + kubectl delete -f milvus_proxy.yaml 2>/dev/null + kubectl delete -f milvus_stateful_servers.yaml 2>/dev/null + kubectl delete -f milvus_write_servers.yaml 2>/dev/null + kubectl delete -f milvus_rbac.yaml 2>/dev/null + # kubectl delete -f milvus_monitor.yaml 2>/dev/null +} + +function scaleDeployment() { + deployment_name=$1 + subcommand=$2 + des=$3 + + case $des in + -h|--help|"") + showscaleHelpMessage $subcommand + exit 3 + ;; + esac + + cur=$(kubectl get deployment -n milvus $deployment_name |grep $deployment_name |awk '{split($2, status, "/"); print status[2];}') + echo -e "${GREEN}Current Running ${BOLD}$cur ${GREEN}${deployment_name}, Scaling to ${BOLD}$des ...${ENDC}"; + scalecmd="kubectl scale deployment -n milvus ${deployment_name} --replicas=${des}" + ${scalecmd} + if [ $? -ne 0 ]; then + echo -e "${RED}${BOLD}Scale Error: ${GREEN}${scalecmd}${ENDC}" + exit 1 + fi + + checkDeployment $deployment_name + + if [ $? -ne 0 ]; then + echo -e "${RED}${BOLD}Scale ${deployment_name} timeout${NORMAL}${ENDC}" + scalecmd="kubectl scale deployment -n milvus ${deployment_name} --replicas=${cur}" + ${scalecmd} + if [ $? -ne 0 ]; then + echo -e "${RED}${BOLD}Scale Rollback Error: ${GREEN}${scalecmd}${ENDC}" + exit 2 + fi + echo -e "${BLUE}${BOLD}Scale Rollback to ${cur}${ENDC}" + exit 1 + fi + PrintScaleSuccessMessage $cur $des +} + +function scaleROServers() { + subcommand=$1 + des=$2 + case $des in + -h|--help|"") + showscaleHelpMessage $subcommand + exit 3 + ;; + esac + + cur=$(kubectl get statefulset -n milvus milvus-ro-servers |tail -n 1 |awk '{split($2, status, "/"); print status[2];}') + echo -e "${GREEN}Current Running ${BOLD}$cur ${GREEN}Readonly Servers, Scaling to ${BOLD}$des ...${ENDC}"; + scalecmd="kubectl scale sts milvus-ro-servers -n milvus --replicas=${des}" + ${scalecmd} + if [ $? -ne 0 ]; then + echo -e "${RED}${BOLD}Scale Error: ${GREEN}${scalecmd}${ENDC}" + exit 1 + fi + + checkStatefulSevers + if [ $? -ne 0 ]; then + echo -e "${RED}${BOLD}Scale milvus-ro-servers timeout${NORMAL}${ENDC}" + scalecmd="kubectl scale sts milvus-ro-servers -n milvus --replicas=${cur}" + ${scalecmd} + if [ $? -ne 0 ]; then + echo -e "${RED}${BOLD}Scale Rollback Error: ${GREEN}${scalecmd}${ENDC}" + exit 2 + fi + echo -e "${BLUE}${BOLD}Scale Rollback to ${cur}${ENDC}" + exit 1 + fi + + PrintScaleSuccessMessage $cur $des +} + + +case "$1" in + +cleanup) + kubectl delete -f . 2>/dev/null + echo -e "${BLUE}${BOLD}All resources are removed${NORMAL}${ENDC}" + ;; + +appdown) + removeApps; + echo -e "${BLUE}${BOLD}All pods are removed${NORMAL}${ENDC}" + ;; + +baseup) + startDependencies; + echo -e "${BLUE}${BOLD}All pvc, configmap and services up${NORMAL}${ENDC}" + ;; + +appup) + startApps; + echo -e "${BLUE}${BOLD}All pods up${NORMAL}${ENDC}" + ;; + +allup) + startDependencies; + sleep 2 + startApps; + echo -e "${BLUE}${BOLD}All resources and pods up${NORMAL}${ENDC}" + ;; + +scale-ro-server) + scaleROServers $1 $2 + ;; + +scale-proxy) + scaleDeployment "milvus-proxy" $1 $2 + ;; + +-h|--help|*) + showHelpMessage + ;; + +esac diff --git a/shards/manager.py b/shards/manager.py new file mode 100644 index 0000000000..4157b9343e --- /dev/null +++ b/shards/manager.py @@ -0,0 +1,18 @@ +import fire +from mishards import db, settings + + +class DBHandler: + @classmethod + def create_all(cls): + db.create_all() + + @classmethod + def drop_all(cls): + db.drop_all() + + +if __name__ == '__main__': + db.init_db(settings.DefaultConfig.SQLALCHEMY_DATABASE_URI) + from mishards import models + fire.Fire(DBHandler) diff --git a/shards/mishards/.env.example b/shards/mishards/.env.example new file mode 100644 index 0000000000..f1c812a269 --- /dev/null +++ b/shards/mishards/.env.example @@ -0,0 +1,36 @@ +DEBUG=True + +WOSERVER=tcp://127.0.0.1:19530 +SERVER_PORT=19532 +SERVER_TEST_PORT=19888 + +#SQLALCHEMY_DATABASE_URI=mysql+pymysql://root:root@127.0.0.1:3306/milvus?charset=utf8mb4 +SQLALCHEMY_DATABASE_URI=sqlite:////tmp/milvus/db/meta.sqlite?check_same_thread=False +SQL_ECHO=False + +#SQLALCHEMY_DATABASE_TEST_URI=mysql+pymysql://root:root@127.0.0.1:3306/milvus?charset=utf8mb4 +SQLALCHEMY_DATABASE_TEST_URI=sqlite:////tmp/milvus/db/meta.sqlite?check_same_thread=False +SQL_TEST_ECHO=False + +TRACER_PLUGIN_PATH=/tmp/plugins + +# TRACING_TEST_TYPE=jaeger +TRACER_CLASS_NAME=jaeger +TRACING_SERVICE_NAME=fortest +TRACING_SAMPLER_TYPE=const +TRACING_SAMPLER_PARAM=1 +TRACING_LOG_PAYLOAD=True +#TRACING_SAMPLER_TYPE=probabilistic +#TRACING_SAMPLER_PARAM=0.5 + +#DISCOVERY_PLUGIN_PATH= +#DISCOVERY_CLASS_NAME=kubernetes + +DISCOVERY_STATIC_HOSTS=127.0.0.1 +DISCOVERY_STATIC_PORT=19530 + +DISCOVERY_KUBERNETES_NAMESPACE=xp +DISCOVERY_KUBERNETES_POD_PATT=.*-ro-servers-.* +DISCOVERY_KUBERNETES_LABEL_SELECTOR=tier=ro-servers +DISCOVERY_KUBERNETES_POLL_INTERVAL=5 +DISCOVERY_KUBERNETES_IN_CLUSTER=False diff --git a/shards/mishards/__init__.py b/shards/mishards/__init__.py new file mode 100644 index 0000000000..a3c55c4ae3 --- /dev/null +++ b/shards/mishards/__init__.py @@ -0,0 +1,40 @@ +import logging +from mishards import settings +logger = logging.getLogger() + +from mishards.db_base import DB +db = DB() + +from mishards.server import Server +grpc_server = Server() + + +def create_app(testing_config=None): + config = testing_config if testing_config else settings.DefaultConfig + db.init_db(uri=config.SQLALCHEMY_DATABASE_URI, echo=config.SQL_ECHO) + + from mishards.connections import ConnectionMgr + connect_mgr = ConnectionMgr() + + from discovery.factory import DiscoveryFactory + discover = DiscoveryFactory(config.DISCOVERY_PLUGIN_PATH).create(config.DISCOVERY_CLASS_NAME, + conn_mgr=connect_mgr) + + from mishards.grpc_utils import GrpcSpanDecorator + from tracer.factory import TracerFactory + tracer = TracerFactory(config.TRACER_PLUGIN_PATH).create(config.TRACER_CLASS_NAME, + plugin_config=settings.TracingConfig, + span_decorator=GrpcSpanDecorator()) + + from mishards.router.factory import RouterFactory + router = RouterFactory(config.ROUTER_PLUGIN_PATH).create(config.ROUTER_CLASS_NAME, + conn_mgr=connect_mgr) + + grpc_server.init_app(conn_mgr=connect_mgr, + tracer=tracer, + router=router, + discover=discover) + + from mishards import exception_handlers + + return grpc_server diff --git a/shards/mishards/connections.py b/shards/mishards/connections.py new file mode 100644 index 0000000000..618690a099 --- /dev/null +++ b/shards/mishards/connections.py @@ -0,0 +1,154 @@ +import logging +import threading +from functools import wraps +from milvus import Milvus + +from mishards import (settings, exceptions) +from utils import singleton + +logger = logging.getLogger(__name__) + + +class Connection: + def __init__(self, name, uri, max_retry=1, error_handlers=None, **kwargs): + self.name = name + self.uri = uri + self.max_retry = max_retry + self.retried = 0 + self.conn = Milvus() + self.error_handlers = [] if not error_handlers else error_handlers + self.on_retry_func = kwargs.get('on_retry_func', None) + # self._connect() + + def __str__(self): + return 'Connection:name=\"{}\";uri=\"{}\"'.format(self.name, self.uri) + + def _connect(self, metadata=None): + try: + self.conn.connect(uri=self.uri) + except Exception as e: + if not self.error_handlers: + raise exceptions.ConnectionConnectError(message=str(e), metadata=metadata) + for handler in self.error_handlers: + handler(e, metadata=metadata) + + @property + def can_retry(self): + return self.retried < self.max_retry + + @property + def connected(self): + return self.conn.connected() + + def on_retry(self): + if self.on_retry_func: + self.on_retry_func(self) + else: + self.retried > 1 and logger.warning('{} is retrying {}'.format(self, self.retried)) + + def on_connect(self, metadata=None): + while not self.connected and self.can_retry: + self.retried += 1 + self.on_retry() + self._connect(metadata=metadata) + + if not self.can_retry and not self.connected: + raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry, + metadata=metadata)) + + self.retried = 0 + + def connect(self, func, exception_handler=None): + @wraps(func) + def inner(*args, **kwargs): + self.on_connect() + try: + return func(*args, **kwargs) + except Exception as e: + if exception_handler: + exception_handler(e) + else: + raise e + return inner + + +@singleton +class ConnectionMgr: + def __init__(self): + self.metas = {} + self.conns = {} + + @property + def conn_names(self): + return set(self.metas.keys()) - set(['WOSERVER']) + + def conn(self, name, metadata, throw=False): + c = self.conns.get(name, None) + if not c: + url = self.metas.get(name, None) + if not url: + if not throw: + return None + raise exceptions.ConnectionNotFoundError(message='Connection {} not found'.format(name), + metadata=metadata) + this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY) + threaded = { + threading.get_ident(): this_conn + } + self.conns[name] = threaded + return this_conn + + tid = threading.get_ident() + rconn = c.get(tid, None) + if not rconn: + url = self.metas.get(name, None) + if not url: + if not throw: + return None + raise exceptions.ConnectionNotFoundError('Connection {} not found'.format(name), + metadata=metadata) + this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY) + c[tid] = this_conn + return this_conn + + return rconn + + def on_new_meta(self, name, url): + logger.info('Register Connection: name={};url={}'.format(name, url)) + self.metas[name] = url + + def on_duplicate_meta(self, name, url): + if self.metas[name] == url: + return self.on_same_meta(name, url) + + return self.on_diff_meta(name, url) + + def on_same_meta(self, name, url): + # logger.warning('Register same meta: {}:{}'.format(name, url)) + pass + + def on_diff_meta(self, name, url): + logger.warning('Received {} with diff url={}'.format(name, url)) + self.metas[name] = url + self.conns[name] = {} + + def on_unregister_meta(self, name, url): + logger.info('Unregister name={};url={}'.format(name, url)) + self.conns.pop(name, None) + + def on_nonexisted_meta(self, name): + logger.warning('Non-existed meta: {}'.format(name)) + + def register(self, name, url): + meta = self.metas.get(name) + if not meta: + return self.on_new_meta(name, url) + else: + return self.on_duplicate_meta(name, url) + + def unregister(self, name): + logger.info('Unregister Connection: name={}'.format(name)) + url = self.metas.pop(name, None) + if url is None: + return self.on_nonexisted_meta(name) + return self.on_unregister_meta(name, url) diff --git a/shards/mishards/db_base.py b/shards/mishards/db_base.py new file mode 100644 index 0000000000..5f2eee9ba1 --- /dev/null +++ b/shards/mishards/db_base.py @@ -0,0 +1,52 @@ +import logging +from sqlalchemy import create_engine +from sqlalchemy.engine.url import make_url +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker, scoped_session +from sqlalchemy.orm.session import Session as SessionBase + +logger = logging.getLogger(__name__) + + +class LocalSession(SessionBase): + def __init__(self, db, autocommit=False, autoflush=True, **options): + self.db = db + bind = options.pop('bind', None) or db.engine + SessionBase.__init__(self, autocommit=autocommit, autoflush=autoflush, bind=bind, **options) + + +class DB: + Model = declarative_base() + + def __init__(self, uri=None, echo=False): + self.echo = echo + uri and self.init_db(uri, echo) + self.session_factory = scoped_session(sessionmaker(class_=LocalSession, db=self)) + + def init_db(self, uri, echo=False): + url = make_url(uri) + if url.get_backend_name() == 'sqlite': + self.engine = create_engine(url) + else: + self.engine = create_engine(uri, pool_size=100, pool_recycle=5, pool_timeout=30, + pool_pre_ping=True, + echo=echo, + max_overflow=0) + self.uri = uri + self.url = url + + def __str__(self): + return ''.format(self.url.get_backend_name(), self.url.database) + + @property + def Session(self): + return self.session_factory() + + def remove_session(self): + self.session_factory.remove() + + def drop_all(self): + self.Model.metadata.drop_all(self.engine) + + def create_all(self): + self.Model.metadata.create_all(self.engine) diff --git a/shards/mishards/exception_codes.py b/shards/mishards/exception_codes.py new file mode 100644 index 0000000000..bdd4572dd5 --- /dev/null +++ b/shards/mishards/exception_codes.py @@ -0,0 +1,10 @@ +INVALID_CODE = -1 + +CONNECT_ERROR_CODE = 10001 +CONNECTTION_NOT_FOUND_CODE = 10002 +DB_ERROR_CODE = 10003 + +TABLE_NOT_FOUND_CODE = 20001 +INVALID_ARGUMENT_CODE = 20002 +INVALID_DATE_RANGE_CODE = 20003 +INVALID_TOPK_CODE = 20004 diff --git a/shards/mishards/exception_handlers.py b/shards/mishards/exception_handlers.py new file mode 100644 index 0000000000..c79a6db5a3 --- /dev/null +++ b/shards/mishards/exception_handlers.py @@ -0,0 +1,82 @@ +import logging +from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2 +from mishards import grpc_server as server, exceptions + +logger = logging.getLogger(__name__) + + +def resp_handler(err, error_code): + if not isinstance(err, exceptions.BaseException): + return status_pb2.Status(error_code=error_code, reason=str(err)) + + status = status_pb2.Status(error_code=error_code, reason=err.message) + + if err.metadata is None: + return status + + resp_class = err.metadata.get('resp_class', None) + if not resp_class: + return status + + if resp_class == milvus_pb2.BoolReply: + return resp_class(status=status, bool_reply=False) + + if resp_class == milvus_pb2.VectorIds: + return resp_class(status=status, vector_id_array=[]) + + if resp_class == milvus_pb2.TopKQueryResultList: + return resp_class(status=status, topk_query_result=[]) + + if resp_class == milvus_pb2.TableRowCount: + return resp_class(status=status, table_row_count=-1) + + if resp_class == milvus_pb2.TableName: + return resp_class(status=status, table_name=[]) + + if resp_class == milvus_pb2.StringReply: + return resp_class(status=status, string_reply='') + + if resp_class == milvus_pb2.TableSchema: + return milvus_pb2.TableSchema( + status=status + ) + + if resp_class == milvus_pb2.IndexParam: + return milvus_pb2.IndexParam( + table_name=milvus_pb2.TableName( + status=status + ) + ) + + status.error_code = status_pb2.UNEXPECTED_ERROR + return status + + +@server.errorhandler(exceptions.TableNotFoundError) +def TableNotFoundErrorHandler(err): + logger.error(err) + return resp_handler(err, status_pb2.TABLE_NOT_EXISTS) + + +@server.errorhandler(exceptions.InvalidTopKError) +def InvalidTopKErrorHandler(err): + logger.error(err) + return resp_handler(err, status_pb2.ILLEGAL_TOPK) + + +@server.errorhandler(exceptions.InvalidArgumentError) +def InvalidArgumentErrorHandler(err): + logger.error(err) + return resp_handler(err, status_pb2.ILLEGAL_ARGUMENT) + + +@server.errorhandler(exceptions.DBError) +def DBErrorHandler(err): + logger.error(err) + return resp_handler(err, status_pb2.UNEXPECTED_ERROR) + + +@server.errorhandler(exceptions.InvalidRangeError) +def InvalidArgumentErrorHandler(err): + logger.error(err) + return resp_handler(err, status_pb2.ILLEGAL_RANGE) diff --git a/shards/mishards/exceptions.py b/shards/mishards/exceptions.py new file mode 100644 index 0000000000..72839f88d2 --- /dev/null +++ b/shards/mishards/exceptions.py @@ -0,0 +1,38 @@ +import mishards.exception_codes as codes + + +class BaseException(Exception): + code = codes.INVALID_CODE + message = 'BaseException' + + def __init__(self, message='', metadata=None): + self.message = self.__class__.__name__ if not message else message + self.metadata = metadata + + +class ConnectionConnectError(BaseException): + code = codes.CONNECT_ERROR_CODE + + +class ConnectionNotFoundError(BaseException): + code = codes.CONNECTTION_NOT_FOUND_CODE + + +class DBError(BaseException): + code = codes.DB_ERROR_CODE + + +class TableNotFoundError(BaseException): + code = codes.TABLE_NOT_FOUND_CODE + + +class InvalidTopKError(BaseException): + code = codes.INVALID_TOPK_CODE + + +class InvalidArgumentError(BaseException): + code = codes.INVALID_ARGUMENT_CODE + + +class InvalidRangeError(BaseException): + code = codes.INVALID_DATE_RANGE_CODE diff --git a/shards/mishards/factories.py b/shards/mishards/factories.py new file mode 100644 index 0000000000..52c0253b39 --- /dev/null +++ b/shards/mishards/factories.py @@ -0,0 +1,54 @@ +import time +import datetime +import random +import factory +from factory.alchemy import SQLAlchemyModelFactory +from faker import Faker +from faker.providers import BaseProvider + +from milvus.client.types import MetricType +from mishards import db +from mishards.models import Tables, TableFiles + + +class FakerProvider(BaseProvider): + def this_date(self): + t = datetime.datetime.today() + return (t.year - 1900) * 10000 + (t.month - 1) * 100 + t.day + + +factory.Faker.add_provider(FakerProvider) + + +class TablesFactory(SQLAlchemyModelFactory): + class Meta: + model = Tables + sqlalchemy_session = db.session_factory + sqlalchemy_session_persistence = 'commit' + + id = factory.Faker('random_number', digits=16, fix_len=True) + table_id = factory.Faker('uuid4') + state = factory.Faker('random_element', elements=(0, 1)) + dimension = factory.Faker('random_element', elements=(256, 512)) + created_on = int(time.time()) + index_file_size = 0 + engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3)) + metric_type = factory.Faker('random_element', elements=(MetricType.L2, MetricType.IP)) + nlist = 16384 + + +class TableFilesFactory(SQLAlchemyModelFactory): + class Meta: + model = TableFiles + sqlalchemy_session = db.session_factory + sqlalchemy_session_persistence = 'commit' + + id = factory.Faker('random_number', digits=16, fix_len=True) + table = factory.SubFactory(TablesFactory) + engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3)) + file_id = factory.Faker('uuid4') + file_type = factory.Faker('random_element', elements=(0, 1, 2, 3, 4)) + file_size = factory.Faker('random_number') + updated_time = int(time.time()) + created_on = int(time.time()) + date = factory.Faker('this_date') diff --git a/shards/mishards/grpc_utils/__init__.py b/shards/mishards/grpc_utils/__init__.py new file mode 100644 index 0000000000..f5225b2a66 --- /dev/null +++ b/shards/mishards/grpc_utils/__init__.py @@ -0,0 +1,37 @@ +from grpc_opentracing import SpanDecorator +from milvus.grpc_gen import status_pb2 + + +class GrpcSpanDecorator(SpanDecorator): + def __call__(self, span, rpc_info): + status = None + if not rpc_info.response: + return + if isinstance(rpc_info.response, status_pb2.Status): + status = rpc_info.response + else: + try: + status = rpc_info.response.status + except Exception as e: + status = status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, + reason='Should not happen') + + if status.error_code == 0: + return + error_log = {'event': 'error', + 'request': rpc_info.request, + 'response': rpc_info.response + } + span.set_tag('error', True) + span.log_kv(error_log) + + +def mark_grpc_method(func): + setattr(func, 'grpc_method', True) + return func + + +def is_grpc_method(func): + if not func: + return False + return getattr(func, 'grpc_method', False) diff --git a/shards/mishards/grpc_utils/grpc_args_parser.py b/shards/mishards/grpc_utils/grpc_args_parser.py new file mode 100644 index 0000000000..039299803d --- /dev/null +++ b/shards/mishards/grpc_utils/grpc_args_parser.py @@ -0,0 +1,102 @@ +from milvus import Status +from functools import wraps + + +def error_status(func): + @wraps(func) + def inner(*args, **kwargs): + try: + results = func(*args, **kwargs) + except Exception as e: + return Status(code=Status.UNEXPECTED_ERROR, message=str(e)), None + + return Status(code=0, message="Success"), results + + return inner + + +class GrpcArgsParser(object): + + @classmethod + @error_status + def parse_proto_TableSchema(cls, param): + _table_schema = { + 'status': param.status, + 'table_name': param.table_name, + 'dimension': param.dimension, + 'index_file_size': param.index_file_size, + 'metric_type': param.metric_type + } + + return _table_schema + + @classmethod + @error_status + def parse_proto_TableName(cls, param): + return param.table_name + + @classmethod + @error_status + def parse_proto_Index(cls, param): + _index = { + 'index_type': param.index_type, + 'nlist': param.nlist + } + + return _index + + @classmethod + @error_status + def parse_proto_IndexParam(cls, param): + _table_name = param.table_name + _status, _index = cls.parse_proto_Index(param.index) + + if not _status.OK(): + raise Exception("Argument parse error") + + return _table_name, _index + + @classmethod + @error_status + def parse_proto_Command(cls, param): + _cmd = param.cmd + + return _cmd + + @classmethod + @error_status + def parse_proto_Range(cls, param): + _start_value = param.start_value + _end_value = param.end_value + + return _start_value, _end_value + + @classmethod + @error_status + def parse_proto_RowRecord(cls, param): + return list(param.vector_data) + + @classmethod + @error_status + def parse_proto_SearchParam(cls, param): + _table_name = param.table_name + _topk = param.topk + _nprobe = param.nprobe + _status, _range = cls.parse_proto_Range(param.query_range_array) + + if not _status.OK(): + raise Exception("Argument parse error") + + _row_record = param.query_record_array + + return _table_name, _row_record, _range, _topk + + @classmethod + @error_status + def parse_proto_DeleteByRangeParam(cls, param): + _table_name = param.table_name + _range = param.range + _start_value = _range.start_value + _end_value = _range.end_value + + return _table_name, _start_value, _end_value diff --git a/shards/mishards/grpc_utils/grpc_args_wrapper.py b/shards/mishards/grpc_utils/grpc_args_wrapper.py new file mode 100644 index 0000000000..7447dbd995 --- /dev/null +++ b/shards/mishards/grpc_utils/grpc_args_wrapper.py @@ -0,0 +1,4 @@ +# class GrpcArgsWrapper(object): + +# @classmethod +# def proto_TableName(cls): diff --git a/shards/mishards/grpc_utils/test_grpc.py b/shards/mishards/grpc_utils/test_grpc.py new file mode 100644 index 0000000000..9af09e5d0d --- /dev/null +++ b/shards/mishards/grpc_utils/test_grpc.py @@ -0,0 +1,75 @@ +import logging +import opentracing +from mishards.grpc_utils import GrpcSpanDecorator, is_grpc_method +from milvus.grpc_gen import status_pb2, milvus_pb2 + +logger = logging.getLogger(__name__) + + +class FakeTracer(opentracing.Tracer): + pass + + +class FakeSpan(opentracing.Span): + def __init__(self, context, tracer, **kwargs): + super(FakeSpan, self).__init__(tracer, context) + self.reset() + + def set_tag(self, key, value): + self.tags.append({key: value}) + + def log_kv(self, key_values, timestamp=None): + self.logs.append(key_values) + + def reset(self): + self.tags = [] + self.logs = [] + + +class FakeRpcInfo: + def __init__(self, request, response): + self.request = request + self.response = response + + +class TestGrpcUtils: + def test_span_deco(self): + request = 'request' + OK = status_pb2.Status(error_code=status_pb2.SUCCESS, reason='Success') + response = OK + rpc_info = FakeRpcInfo(request=request, response=response) + span = FakeSpan(context=None, tracer=FakeTracer()) + span_deco = GrpcSpanDecorator() + span_deco(span, rpc_info) + assert len(span.logs) == 0 + assert len(span.tags) == 0 + + response = milvus_pb2.BoolReply(status=OK, bool_reply=False) + rpc_info = FakeRpcInfo(request=request, response=response) + span = FakeSpan(context=None, tracer=FakeTracer()) + span_deco = GrpcSpanDecorator() + span_deco(span, rpc_info) + assert len(span.logs) == 0 + assert len(span.tags) == 0 + + response = 1 + rpc_info = FakeRpcInfo(request=request, response=response) + span = FakeSpan(context=None, tracer=FakeTracer()) + span_deco = GrpcSpanDecorator() + span_deco(span, rpc_info) + assert len(span.logs) == 1 + assert len(span.tags) == 1 + + response = 0 + rpc_info = FakeRpcInfo(request=request, response=response) + span = FakeSpan(context=None, tracer=FakeTracer()) + span_deco = GrpcSpanDecorator() + span_deco(span, rpc_info) + assert len(span.logs) == 0 + assert len(span.tags) == 0 + + def test_is_grpc_method(self): + target = 1 + assert not is_grpc_method(target) + target = None + assert not is_grpc_method(target) diff --git a/shards/mishards/hash_ring.py b/shards/mishards/hash_ring.py new file mode 100644 index 0000000000..a97f3f580e --- /dev/null +++ b/shards/mishards/hash_ring.py @@ -0,0 +1,150 @@ +import math +import sys +from bisect import bisect + +if sys.version_info >= (2, 5): + import hashlib + md5_constructor = hashlib.md5 +else: + import md5 + md5_constructor = md5.new + + +class HashRing(object): + def __init__(self, nodes=None, weights=None): + """`nodes` is a list of objects that have a proper __str__ representation. + `weights` is dictionary that sets weights to the nodes. The default + weight is that all nodes are equal. + """ + self.ring = dict() + self._sorted_keys = [] + + self.nodes = nodes + + if not weights: + weights = {} + self.weights = weights + + self._generate_circle() + + def _generate_circle(self): + """Generates the circle. + """ + total_weight = 0 + for node in self.nodes: + total_weight += self.weights.get(node, 1) + + for node in self.nodes: + weight = 1 + + if node in self.weights: + weight = self.weights.get(node) + + factor = math.floor((40 * len(self.nodes) * weight) / total_weight) + + for j in range(0, int(factor)): + b_key = self._hash_digest('%s-%s' % (node, j)) + + for i in range(0, 3): + key = self._hash_val(b_key, lambda x: x + i * 4) + self.ring[key] = node + self._sorted_keys.append(key) + + self._sorted_keys.sort() + + def get_node(self, string_key): + """Given a string key a corresponding node in the hash ring is returned. + + If the hash ring is empty, `None` is returned. + """ + pos = self.get_node_pos(string_key) + if pos is None: + return None + return self.ring[self._sorted_keys[pos]] + + def get_node_pos(self, string_key): + """Given a string key a corresponding node in the hash ring is returned + along with it's position in the ring. + + If the hash ring is empty, (`None`, `None`) is returned. + """ + if not self.ring: + return None + + key = self.gen_key(string_key) + + nodes = self._sorted_keys + pos = bisect(nodes, key) + + if pos == len(nodes): + return 0 + else: + return pos + + def iterate_nodes(self, string_key, distinct=True): + """Given a string key it returns the nodes as a generator that can hold the key. + + The generator iterates one time through the ring + starting at the correct position. + + if `distinct` is set, then the nodes returned will be unique, + i.e. no virtual copies will be returned. + """ + if not self.ring: + yield None, None + + returned_values = set() + + def distinct_filter(value): + if str(value) not in returned_values: + returned_values.add(str(value)) + return value + + pos = self.get_node_pos(string_key) + for key in self._sorted_keys[pos:]: + val = distinct_filter(self.ring[key]) + if val: + yield val + + for i, key in enumerate(self._sorted_keys): + if i < pos: + val = distinct_filter(self.ring[key]) + if val: + yield val + + def gen_key(self, key): + """Given a string key it returns a long value, + this long value represents a place on the hash ring. + + md5 is currently used because it mixes well. + """ + b_key = self._hash_digest(key) + return self._hash_val(b_key, lambda x: x) + + def _hash_val(self, b_key, entry_fn): + return (b_key[entry_fn(3)] << 24) | (b_key[entry_fn(2)] << 16) | ( + b_key[entry_fn(1)] << 8) | b_key[entry_fn(0)] + + def _hash_digest(self, key): + m = md5_constructor() + key = key.encode() + m.update(key) + return m.digest() + + +if __name__ == '__main__': + from collections import defaultdict + servers = [ + '192.168.0.246:11212', '192.168.0.247:11212', '192.168.0.248:11212', + '192.168.0.249:11212' + ] + + ring = HashRing(servers) + keys = ['{}'.format(i) for i in range(100)] + mapped = defaultdict(list) + for k in keys: + server = ring.get_node(k) + mapped[server].append(k) + + for k, v in mapped.items(): + print(k, v) diff --git a/shards/mishards/main.py b/shards/mishards/main.py new file mode 100644 index 0000000000..c0d142607b --- /dev/null +++ b/shards/mishards/main.py @@ -0,0 +1,15 @@ +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from mishards import (settings, create_app) + + +def main(): + server = create_app(settings.DefaultConfig) + server.run(port=settings.SERVER_PORT) + return 0 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/shards/mishards/models.py b/shards/mishards/models.py new file mode 100644 index 0000000000..4b6c8f9ef4 --- /dev/null +++ b/shards/mishards/models.py @@ -0,0 +1,76 @@ +import logging +from sqlalchemy import (Integer, Boolean, Text, + String, BigInteger, and_, or_, + Column) +from sqlalchemy.orm import relationship, backref + +from mishards import db + +logger = logging.getLogger(__name__) + + +class TableFiles(db.Model): + FILE_TYPE_NEW = 0 + FILE_TYPE_RAW = 1 + FILE_TYPE_TO_INDEX = 2 + FILE_TYPE_INDEX = 3 + FILE_TYPE_TO_DELETE = 4 + FILE_TYPE_NEW_MERGE = 5 + FILE_TYPE_NEW_INDEX = 6 + FILE_TYPE_BACKUP = 7 + + __tablename__ = 'TableFiles' + + id = Column(BigInteger, primary_key=True, autoincrement=True) + table_id = Column(String(50)) + engine_type = Column(Integer) + file_id = Column(String(50)) + file_type = Column(Integer) + file_size = Column(Integer, default=0) + row_count = Column(Integer, default=0) + updated_time = Column(BigInteger) + created_on = Column(BigInteger) + date = Column(Integer) + + table = relationship( + 'Tables', + primaryjoin='and_(foreign(TableFiles.table_id) == Tables.table_id)', + backref=backref('files', uselist=True, lazy='dynamic') + ) + + +class Tables(db.Model): + TO_DELETE = 1 + NORMAL = 0 + + __tablename__ = 'Tables' + + id = Column(BigInteger, primary_key=True, autoincrement=True) + table_id = Column(String(50), unique=True) + state = Column(Integer) + dimension = Column(Integer) + created_on = Column(Integer) + flag = Column(Integer, default=0) + index_file_size = Column(Integer) + engine_type = Column(Integer) + nlist = Column(Integer) + metric_type = Column(Integer) + + def files_to_search(self, date_range=None): + cond = or_( + TableFiles.file_type == TableFiles.FILE_TYPE_RAW, + TableFiles.file_type == TableFiles.FILE_TYPE_TO_INDEX, + TableFiles.file_type == TableFiles.FILE_TYPE_INDEX, + ) + if date_range: + cond = and_( + cond, + or_( + and_(TableFiles.date >= d[0], TableFiles.date < d[1]) for d in date_range + ) + ) + + files = self.files.filter(cond) + + logger.debug('DATE_RANGE: {}'.format(date_range)) + return files diff --git a/shards/mishards/router/__init__.py b/shards/mishards/router/__init__.py new file mode 100644 index 0000000000..4150f3b736 --- /dev/null +++ b/shards/mishards/router/__init__.py @@ -0,0 +1,22 @@ +from mishards import exceptions + + +class RouterMixin: + def __init__(self, conn_mgr): + self.conn_mgr = conn_mgr + + def routing(self, table_name, metadata=None, **kwargs): + raise NotImplemented() + + def connection(self, metadata=None): + conn = self.conn_mgr.conn('WOSERVER', metadata=metadata) + if conn: + conn.on_connect(metadata=metadata) + return conn.conn + + def query_conn(self, name, metadata=None): + conn = self.conn_mgr.conn(name, metadata=metadata) + if not conn: + raise exceptions.ConnectionNotFoundError(name, metadata=metadata) + conn.on_connect(metadata=metadata) + return conn.conn diff --git a/shards/mishards/router/factory.py b/shards/mishards/router/factory.py new file mode 100644 index 0000000000..a8f85c0df8 --- /dev/null +++ b/shards/mishards/router/factory.py @@ -0,0 +1,17 @@ +import os +import logging +from utils.plugins import BaseMixin + +logger = logging.getLogger(__name__) +PLUGIN_PACKAGE_NAME = 'mishards.router.plugins' + + +class RouterFactory(BaseMixin): + PLUGIN_TYPE = 'Router' + + def __init__(self, searchpath=None): + super().__init__(searchpath=searchpath, package_name=PLUGIN_PACKAGE_NAME) + + def _create(self, plugin_class, **kwargs): + router = plugin_class.Create(**kwargs) + return router diff --git a/shards/mishards/router/plugins/file_based_hash_ring_router.py b/shards/mishards/router/plugins/file_based_hash_ring_router.py new file mode 100644 index 0000000000..b90935129e --- /dev/null +++ b/shards/mishards/router/plugins/file_based_hash_ring_router.py @@ -0,0 +1,64 @@ +import logging +from sqlalchemy import exc as sqlalchemy_exc +from sqlalchemy import and_ +from mishards.models import Tables +from mishards.router import RouterMixin +from mishards import exceptions, db +from mishards.hash_ring import HashRing + +logger = logging.getLogger(__name__) + + +class Factory(RouterMixin): + name = 'FileBasedHashRingRouter' + + def __init__(self, conn_mgr, **kwargs): + super(Factory, self).__init__(conn_mgr) + + def routing(self, table_name, metadata=None, **kwargs): + range_array = kwargs.pop('range_array', None) + return self._route(table_name, range_array, metadata, **kwargs) + + def _route(self, table_name, range_array, metadata=None, **kwargs): + # PXU TODO: Implement Thread-local Context + # PXU TODO: Session life mgt + try: + table = db.Session.query(Tables).filter( + and_(Tables.table_id == table_name, + Tables.state != Tables.TO_DELETE)).first() + except sqlalchemy_exc.SQLAlchemyError as e: + raise exceptions.DBError(message=str(e), metadata=metadata) + + if not table: + raise exceptions.TableNotFoundError(table_name, metadata=metadata) + files = table.files_to_search(range_array) + db.remove_session() + + servers = self.conn_mgr.conn_names + logger.info('Available servers: {}'.format(servers)) + + ring = HashRing(servers) + + routing = {} + + for f in files: + target_host = ring.get_node(str(f.id)) + sub = routing.get(target_host, None) + if not sub: + routing[target_host] = {'table_id': table_name, 'file_ids': []} + routing[target_host]['file_ids'].append(str(f.id)) + + return routing + + @classmethod + def Create(cls, **kwargs): + conn_mgr = kwargs.pop('conn_mgr', None) + if not conn_mgr: + raise RuntimeError('Cannot find \'conn_mgr\' to initialize \'{}\''.format(self.name)) + router = cls(conn_mgr, **kwargs) + return router + + +def setup(app): + logger.info('Plugin \'{}\' Installed In Package: {}'.format(__file__, app.plugin_package_name)) + app.on_plugin_setup(Factory) diff --git a/shards/mishards/server.py b/shards/mishards/server.py new file mode 100644 index 0000000000..599a00e455 --- /dev/null +++ b/shards/mishards/server.py @@ -0,0 +1,122 @@ +import logging +import grpc +import time +import socket +import inspect +from urllib.parse import urlparse +from functools import wraps +from concurrent import futures +from grpc._cython import cygrpc +from milvus.grpc_gen.milvus_pb2_grpc import add_MilvusServiceServicer_to_server +from mishards.grpc_utils import is_grpc_method +from mishards.service_handler import ServiceHandler +from mishards import settings + +logger = logging.getLogger(__name__) + + +class Server: + def __init__(self): + self.pre_run_handlers = set() + self.grpc_methods = set() + self.error_handlers = {} + self.exit_flag = False + + def init_app(self, + conn_mgr, + tracer, + router, + discover, + port=19530, + max_workers=10, + **kwargs): + self.port = int(port) + self.conn_mgr = conn_mgr + self.tracer = tracer + self.router = router + self.discover = discover + + self.server_impl = grpc.server( + thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers), + options=[(cygrpc.ChannelArgKey.max_send_message_length, -1), + (cygrpc.ChannelArgKey.max_receive_message_length, -1)]) + + self.server_impl = self.tracer.decorate(self.server_impl) + + self.register_pre_run_handler(self.pre_run_handler) + + def pre_run_handler(self): + woserver = settings.WOSERVER + url = urlparse(woserver) + ip = socket.gethostbyname(url.hostname) + socket.inet_pton(socket.AF_INET, ip) + self.conn_mgr.register( + 'WOSERVER', '{}://{}:{}'.format(url.scheme, ip, url.port or 80)) + + def register_pre_run_handler(self, func): + logger.info('Regiterring {} into server pre_run_handlers'.format(func)) + self.pre_run_handlers.add(func) + return func + + def wrap_method_with_errorhandler(self, func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + if e.__class__ in self.error_handlers: + return self.error_handlers[e.__class__](e) + raise + + return wrapper + + def errorhandler(self, exception): + if inspect.isclass(exception) and issubclass(exception, Exception): + + def wrapper(func): + self.error_handlers[exception] = func + return func + + return wrapper + return exception + + def on_pre_run(self): + for handler in self.pre_run_handlers: + handler() + self.discover.start() + + def start(self, port=None): + handler_class = self.decorate_handler(ServiceHandler) + add_MilvusServiceServicer_to_server( + handler_class(tracer=self.tracer, + router=self.router), self.server_impl) + self.server_impl.add_insecure_port("[::]:{}".format( + str(port or self.port))) + self.server_impl.start() + + def run(self, port): + logger.info('Milvus server start ......') + port = port or self.port + self.on_pre_run() + + self.start(port) + logger.info('Listening on port {}'.format(port)) + + try: + while not self.exit_flag: + time.sleep(5) + except KeyboardInterrupt: + self.stop() + + def stop(self): + logger.info('Server is shuting down ......') + self.exit_flag = True + self.server_impl.stop(0) + self.tracer.close() + logger.info('Server is closed') + + def decorate_handler(self, handler): + for key, attr in handler.__dict__.items(): + if is_grpc_method(attr): + setattr(handler, key, self.wrap_method_with_errorhandler(attr)) + return handler diff --git a/shards/mishards/service_handler.py b/shards/mishards/service_handler.py new file mode 100644 index 0000000000..2f19152ae6 --- /dev/null +++ b/shards/mishards/service_handler.py @@ -0,0 +1,475 @@ +import logging +import time +import datetime +from collections import defaultdict + +import multiprocessing +from concurrent.futures import ThreadPoolExecutor +from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2 +from milvus.grpc_gen.milvus_pb2 import TopKQueryResult +from milvus.client.abstract import Range +from milvus.client import types as Types + +from mishards import (db, settings, exceptions) +from mishards.grpc_utils import mark_grpc_method +from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser +from mishards import utilities + +logger = logging.getLogger(__name__) + + +class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): + MAX_NPROBE = 2048 + MAX_TOPK = 2048 + + def __init__(self, tracer, router, max_workers=multiprocessing.cpu_count(), **kwargs): + self.table_meta = {} + self.error_handlers = {} + self.tracer = tracer + self.router = router + self.max_workers = max_workers + + def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs): + status = status_pb2.Status(error_code=status_pb2.SUCCESS, + reason="Success") + if not files_n_topk_results: + return status, [] + + request_results = defaultdict(list) + + calc_time = time.time() + for files_collection in files_n_topk_results: + if isinstance(files_collection, tuple): + status, _ = files_collection + return status, [] + for request_pos, each_request_results in enumerate( + files_collection.topk_query_result): + request_results[request_pos].extend( + each_request_results.query_result_arrays) + request_results[request_pos] = sorted( + request_results[request_pos], + key=lambda x: x.distance, + reverse=reverse)[:topk] + + calc_time = time.time() - calc_time + logger.info('Merge takes {}'.format(calc_time)) + + results = sorted(request_results.items()) + topk_query_result = [] + + for result in results: + query_result = TopKQueryResult(query_result_arrays=result[1]) + topk_query_result.append(query_result) + + return status, topk_query_result + + def _do_query(self, + context, + table_id, + table_meta, + vectors, + topk, + nprobe, + range_array=None, + **kwargs): + metadata = kwargs.get('metadata', None) + range_array = [ + utilities.range_to_date(r, metadata=metadata) for r in range_array + ] if range_array else None + + routing = {} + p_span = None if self.tracer.empty else context.get_active_span( + ).context + with self.tracer.start_span('get_routing', child_of=p_span): + routing = self.router.routing(table_id, + range_array=range_array, + metadata=metadata) + logger.info('Routing: {}'.format(routing)) + + metadata = kwargs.get('metadata', None) + + rs = [] + all_topk_results = [] + + def search(addr, query_params, vectors, topk, nprobe, **kwargs): + logger.info( + 'Send Search Request: addr={};params={};nq={};topk={};nprobe={}' + .format(addr, query_params, len(vectors), topk, nprobe)) + + conn = self.router.query_conn(addr, metadata=metadata) + start = time.time() + span = kwargs.get('span', None) + span = span if span else (None if self.tracer.empty else + context.get_active_span().context) + + with self.tracer.start_span('search_{}'.format(addr), + child_of=span): + ret = conn.search_vectors_in_files( + table_name=query_params['table_id'], + file_ids=query_params['file_ids'], + query_records=vectors, + top_k=topk, + nprobe=nprobe, + lazy_=True) + end = time.time() + logger.info('search_vectors_in_files takes: {}'.format(end - start)) + + all_topk_results.append(ret) + + with self.tracer.start_span('do_search', child_of=p_span) as span: + with ThreadPoolExecutor(max_workers=self.max_workers) as pool: + for addr, params in routing.items(): + res = pool.submit(search, + addr, + params, + vectors, + topk, + nprobe, + span=span) + rs.append(res) + + for res in rs: + res.result() + + reverse = table_meta.metric_type == Types.MetricType.IP + with self.tracer.start_span('do_merge', child_of=p_span): + return self._do_merge(all_topk_results, + topk, + reverse=reverse, + metadata=metadata) + + def _create_table(self, table_schema): + return self.router.connection().create_table(table_schema) + + @mark_grpc_method + def CreateTable(self, request, context): + _status, _table_schema = Parser.parse_proto_TableSchema(request) + + if not _status.OK(): + return status_pb2.Status(error_code=_status.code, + reason=_status.message) + + logger.info('CreateTable {}'.format(_table_schema['table_name'])) + + _status = self._create_table(_table_schema) + + return status_pb2.Status(error_code=_status.code, + reason=_status.message) + + def _has_table(self, table_name, metadata=None): + return self.router.connection(metadata=metadata).has_table(table_name) + + @mark_grpc_method + def HasTable(self, request, context): + _status, _table_name = Parser.parse_proto_TableName(request) + + if not _status.OK(): + return milvus_pb2.BoolReply(status=status_pb2.Status( + error_code=_status.code, reason=_status.message), + bool_reply=False) + + logger.info('HasTable {}'.format(_table_name)) + + _status, _bool = self._has_table(_table_name, + metadata={'resp_class': milvus_pb2.BoolReply}) + + return milvus_pb2.BoolReply(status=status_pb2.Status( + error_code=_status.code, reason=_status.message), + bool_reply=_bool) + + def _delete_table(self, table_name): + return self.router.connection().delete_table(table_name) + + @mark_grpc_method + def DropTable(self, request, context): + _status, _table_name = Parser.parse_proto_TableName(request) + + if not _status.OK(): + return status_pb2.Status(error_code=_status.code, + reason=_status.message) + + logger.info('DropTable {}'.format(_table_name)) + + _status = self._delete_table(_table_name) + + return status_pb2.Status(error_code=_status.code, + reason=_status.message) + + def _create_index(self, table_name, index): + return self.router.connection().create_index(table_name, index) + + @mark_grpc_method + def CreateIndex(self, request, context): + _status, unpacks = Parser.parse_proto_IndexParam(request) + + if not _status.OK(): + return status_pb2.Status(error_code=_status.code, + reason=_status.message) + + _table_name, _index = unpacks + + logger.info('CreateIndex {}'.format(_table_name)) + + # TODO: interface create_table incompleted + _status = self._create_index(_table_name, _index) + + return status_pb2.Status(error_code=_status.code, + reason=_status.message) + + def _add_vectors(self, param, metadata=None): + return self.router.connection(metadata=metadata).add_vectors( + None, None, insert_param=param) + + @mark_grpc_method + def Insert(self, request, context): + logger.info('Insert') + # TODO: Ths SDK interface add_vectors() could update, add a key 'row_id_array' + _status, _ids = self._add_vectors( + metadata={'resp_class': milvus_pb2.VectorIds}, param=request) + return milvus_pb2.VectorIds(status=status_pb2.Status( + error_code=_status.code, reason=_status.message), + vector_id_array=_ids) + + @mark_grpc_method + def Search(self, request, context): + + table_name = request.table_name + + topk = request.topk + nprobe = request.nprobe + + logger.info('Search {}: topk={} nprobe={}'.format( + table_name, topk, nprobe)) + + metadata = {'resp_class': milvus_pb2.TopKQueryResultList} + + if nprobe > self.MAX_NPROBE or nprobe <= 0: + raise exceptions.InvalidArgumentError( + message='Invalid nprobe: {}'.format(nprobe), metadata=metadata) + + if topk > self.MAX_TOPK or topk <= 0: + raise exceptions.InvalidTopKError( + message='Invalid topk: {}'.format(topk), metadata=metadata) + + table_meta = self.table_meta.get(table_name, None) + + if not table_meta: + status, info = self.router.connection( + metadata=metadata).describe_table(table_name) + if not status.OK(): + raise exceptions.TableNotFoundError(table_name, + metadata=metadata) + + self.table_meta[table_name] = info + table_meta = info + + start = time.time() + + query_record_array = [] + + for query_record in request.query_record_array: + query_record_array.append(list(query_record.vector_data)) + + query_range_array = [] + for query_range in request.query_range_array: + query_range_array.append( + Range(query_range.start_value, query_range.end_value)) + + status, results = self._do_query(context, + table_name, + table_meta, + query_record_array, + topk, + nprobe, + query_range_array, + metadata=metadata) + + now = time.time() + logger.info('SearchVector takes: {}'.format(now - start)) + + topk_result_list = milvus_pb2.TopKQueryResultList( + status=status_pb2.Status(error_code=status.error_code, + reason=status.reason), + topk_query_result=results) + return topk_result_list + + @mark_grpc_method + def SearchInFiles(self, request, context): + raise NotImplemented() + + def _describe_table(self, table_name, metadata=None): + return self.router.connection(metadata=metadata).describe_table(table_name) + + @mark_grpc_method + def DescribeTable(self, request, context): + _status, _table_name = Parser.parse_proto_TableName(request) + + if not _status.OK(): + return milvus_pb2.TableSchema(status=status_pb2.Status( + error_code=_status.code, reason=_status.message), ) + + metadata = {'resp_class': milvus_pb2.TableSchema} + + logger.info('DescribeTable {}'.format(_table_name)) + _status, _table = self._describe_table(metadata=metadata, + table_name=_table_name) + + if _status.OK(): + return milvus_pb2.TableSchema( + table_name=_table_name, + index_file_size=_table.index_file_size, + dimension=_table.dimension, + metric_type=_table.metric_type, + status=status_pb2.Status(error_code=_status.code, + reason=_status.message), + ) + + return milvus_pb2.TableSchema( + table_name=_table_name, + status=status_pb2.Status(error_code=_status.code, + reason=_status.message), + ) + + def _count_table(self, table_name, metadata=None): + return self.router.connection( + metadata=metadata).get_table_row_count(table_name) + + @mark_grpc_method + def CountTable(self, request, context): + _status, _table_name = Parser.parse_proto_TableName(request) + + if not _status.OK(): + status = status_pb2.Status(error_code=_status.code, + reason=_status.message) + + return milvus_pb2.TableRowCount(status=status) + + logger.info('CountTable {}'.format(_table_name)) + + metadata = {'resp_class': milvus_pb2.TableRowCount} + _status, _count = self._count_table(_table_name, metadata=metadata) + + return milvus_pb2.TableRowCount( + status=status_pb2.Status(error_code=_status.code, + reason=_status.message), + table_row_count=_count if isinstance(_count, int) else -1) + + def _get_server_version(self, metadata=None): + return self.router.connection(metadata=metadata).server_version() + + @mark_grpc_method + def Cmd(self, request, context): + _status, _cmd = Parser.parse_proto_Command(request) + logger.info('Cmd: {}'.format(_cmd)) + + if not _status.OK(): + return milvus_pb2.StringReply(status=status_pb2.Status( + error_code=_status.code, reason=_status.message)) + + metadata = {'resp_class': milvus_pb2.StringReply} + + if _cmd == 'version': + _status, _reply = self._get_server_version(metadata=metadata) + else: + _status, _reply = self.router.connection( + metadata=metadata).server_status() + + return milvus_pb2.StringReply(status=status_pb2.Status( + error_code=_status.code, reason=_status.message), + string_reply=_reply) + + def _show_tables(self, metadata=None): + return self.router.connection(metadata=metadata).show_tables() + + @mark_grpc_method + def ShowTables(self, request, context): + logger.info('ShowTables') + metadata = {'resp_class': milvus_pb2.TableName} + _status, _results = self._show_tables(metadata=metadata) + + return milvus_pb2.TableNameList(status=status_pb2.Status( + error_code=_status.code, reason=_status.message), + table_names=_results) + + def _delete_by_range(self, table_name, start_date, end_date): + return self.router.connection().delete_vectors_by_range(table_name, + start_date, + end_date) + + @mark_grpc_method + def DeleteByRange(self, request, context): + _status, unpacks = \ + Parser.parse_proto_DeleteByRangeParam(request) + + if not _status.OK(): + return status_pb2.Status(error_code=_status.code, + reason=_status.message) + + _table_name, _start_date, _end_date = unpacks + + logger.info('DeleteByRange {}: {} {}'.format(_table_name, _start_date, + _end_date)) + _status = self._delete_by_range(_table_name, _start_date, _end_date) + return status_pb2.Status(error_code=_status.code, + reason=_status.message) + + def _preload_table(self, table_name): + return self.router.connection().preload_table(table_name) + + @mark_grpc_method + def PreloadTable(self, request, context): + _status, _table_name = Parser.parse_proto_TableName(request) + + if not _status.OK(): + return status_pb2.Status(error_code=_status.code, + reason=_status.message) + + logger.info('PreloadTable {}'.format(_table_name)) + _status = self._preload_table(_table_name) + return status_pb2.Status(error_code=_status.code, + reason=_status.message) + + def _describe_index(self, table_name, metadata=None): + return self.router.connection(metadata=metadata).describe_index(table_name) + + @mark_grpc_method + def DescribeIndex(self, request, context): + _status, _table_name = Parser.parse_proto_TableName(request) + + if not _status.OK(): + return milvus_pb2.IndexParam(status=status_pb2.Status( + error_code=_status.code, reason=_status.message)) + + metadata = {'resp_class': milvus_pb2.IndexParam} + + logger.info('DescribeIndex {}'.format(_table_name)) + _status, _index_param = self._describe_index(table_name=_table_name, + metadata=metadata) + + if not _index_param: + return milvus_pb2.IndexParam(status=status_pb2.Status( + error_code=_status.code, reason=_status.message)) + + _index = milvus_pb2.Index(index_type=_index_param._index_type, + nlist=_index_param._nlist) + + return milvus_pb2.IndexParam(status=status_pb2.Status( + error_code=_status.code, reason=_status.message), + table_name=_table_name, + index=_index) + + def _drop_index(self, table_name): + return self.router.connection().drop_index(table_name) + + @mark_grpc_method + def DropIndex(self, request, context): + _status, _table_name = Parser.parse_proto_TableName(request) + + if not _status.OK(): + return status_pb2.Status(error_code=_status.code, + reason=_status.message) + + logger.info('DropIndex {}'.format(_table_name)) + _status = self._drop_index(_table_name) + return status_pb2.Status(error_code=_status.code, + reason=_status.message) diff --git a/shards/mishards/settings.py b/shards/mishards/settings.py new file mode 100644 index 0000000000..8d7361dddc --- /dev/null +++ b/shards/mishards/settings.py @@ -0,0 +1,69 @@ +import sys +import os + +from environs import Env +env = Env() + +FROM_EXAMPLE = env.bool('FROM_EXAMPLE', False) +if FROM_EXAMPLE: + from dotenv import load_dotenv + load_dotenv('./mishards/.env.example') +else: + env.read_env() + + +DEBUG = env.bool('DEBUG', False) +MAX_RETRY = env.int('MAX_RETRY', 3) + +LOG_LEVEL = env.str('LOG_LEVEL', 'DEBUG' if DEBUG else 'INFO') +LOG_PATH = env.str('LOG_PATH', '/tmp/mishards') +LOG_NAME = env.str('LOG_NAME', 'logfile') +TIMEZONE = env.str('TIMEZONE', 'UTC') + +from utils.logger_helper import config +config(LOG_LEVEL, LOG_PATH, LOG_NAME, TIMEZONE) + +SERVER_PORT = env.int('SERVER_PORT', 19530) +SERVER_TEST_PORT = env.int('SERVER_TEST_PORT', 19530) +WOSERVER = env.str('WOSERVER') + + +class TracingConfig: + TRACING_SERVICE_NAME = env.str('TRACING_SERVICE_NAME', 'mishards') + TRACING_VALIDATE = env.bool('TRACING_VALIDATE', True) + TRACING_LOG_PAYLOAD = env.bool('TRACING_LOG_PAYLOAD', False) + TRACING_CONFIG = { + 'sampler': { + 'type': env.str('TRACING_SAMPLER_TYPE', 'const'), + 'param': env.str('TRACING_SAMPLER_PARAM', "1"), + }, + 'local_agent': { + 'reporting_host': env.str('TRACING_REPORTING_HOST', '127.0.0.1'), + 'reporting_port': env.str('TRACING_REPORTING_PORT', '5775') + }, + 'logging': env.bool('TRACING_LOGGING', True) + } + DEFAULT_TRACING_CONFIG = { + 'sampler': { + 'type': env.str('TRACING_SAMPLER_TYPE', 'const'), + 'param': env.str('TRACING_SAMPLER_PARAM', "0"), + } + } + + +class DefaultConfig: + SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_URI') + SQL_ECHO = env.bool('SQL_ECHO', False) + TRACER_PLUGIN_PATH = env.str('TRACER_PLUGIN_PATH', '') + TRACER_CLASS_NAME = env.str('TRACER_CLASS_NAME', '') + ROUTER_PLUGIN_PATH = env.str('ROUTER_PLUGIN_PATH', '') + ROUTER_CLASS_NAME = env.str('ROUTER_CLASS_NAME', 'FileBasedHashRingRouter') + DISCOVERY_PLUGIN_PATH = env.str('DISCOVERY_PLUGIN_PATH', '') + DISCOVERY_CLASS_NAME = env.str('DISCOVERY_CLASS_NAME', 'static') + + +class TestingConfig(DefaultConfig): + SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_TEST_URI', '') + SQL_ECHO = env.bool('SQL_TEST_ECHO', False) + TRACER_CLASS_NAME = env.str('TRACER_CLASS_TEST_NAME', '') + ROUTER_CLASS_NAME = env.str('ROUTER_CLASS_TEST_NAME', 'FileBasedHashRingRouter') diff --git a/shards/mishards/test_connections.py b/shards/mishards/test_connections.py new file mode 100644 index 0000000000..819d2e03da --- /dev/null +++ b/shards/mishards/test_connections.py @@ -0,0 +1,101 @@ +import logging +import pytest +import mock + +from milvus import Milvus +from mishards.connections import (ConnectionMgr, Connection) +from mishards import exceptions + +logger = logging.getLogger(__name__) + + +@pytest.mark.usefixtures('app') +class TestConnection: + def test_manager(self): + mgr = ConnectionMgr() + + mgr.register('pod1', '111') + mgr.register('pod2', '222') + mgr.register('pod2', '222') + mgr.register('pod2', '2222') + assert len(mgr.conn_names) == 2 + + mgr.unregister('pod1') + assert len(mgr.conn_names) == 1 + + mgr.unregister('pod2') + assert len(mgr.conn_names) == 0 + + mgr.register('WOSERVER', 'xxxx') + assert len(mgr.conn_names) == 0 + + assert not mgr.conn('XXXX', None) + with pytest.raises(exceptions.ConnectionNotFoundError): + mgr.conn('XXXX', None, True) + + mgr.conn('WOSERVER', None) + + def test_connection(self): + class Conn: + def __init__(self, state): + self.state = state + + def connect(self, uri): + return self.state + + def connected(self): + return self.state + + FAIL_CONN = Conn(False) + PASS_CONN = Conn(True) + + class Retry: + def __init__(self): + self.times = 0 + + def __call__(self, conn): + self.times += 1 + logger.info('Retrying {}'.format(self.times)) + + class Func(): + def __init__(self): + self.executed = False + + def __call__(self): + self.executed = True + + max_retry = 3 + + RetryObj = Retry() + + c = Connection('client', + uri='xx', + max_retry=max_retry, + on_retry_func=RetryObj) + c.conn = FAIL_CONN + ff = Func() + this_connect = c.connect(func=ff) + with pytest.raises(exceptions.ConnectionConnectError): + this_connect() + assert RetryObj.times == max_retry + assert not ff.executed + RetryObj = Retry() + + c.conn = PASS_CONN + this_connect = c.connect(func=ff) + this_connect() + assert ff.executed + assert RetryObj.times == 0 + + this_connect = c.connect(func=None) + with pytest.raises(TypeError): + this_connect() + + errors = [] + + def error_handler(err): + errors.append(err) + + this_connect = c.connect(func=None, exception_handler=error_handler) + this_connect() + assert len(errors) == 1 diff --git a/shards/mishards/test_models.py b/shards/mishards/test_models.py new file mode 100644 index 0000000000..d60b62713e --- /dev/null +++ b/shards/mishards/test_models.py @@ -0,0 +1,39 @@ +import logging +import pytest +from mishards.factories import TableFiles, Tables, TableFilesFactory, TablesFactory +from mishards import db, create_app, settings +from mishards.factories import ( + Tables, TableFiles, + TablesFactory, TableFilesFactory +) + +logger = logging.getLogger(__name__) + + +@pytest.mark.usefixtures('app') +class TestModels: + def test_files_to_search(self): + table = TablesFactory() + new_files_cnt = 5 + to_index_cnt = 10 + raw_cnt = 20 + backup_cnt = 12 + to_delete_cnt = 9 + index_cnt = 8 + new_index_cnt = 6 + new_merge_cnt = 11 + + new_files = TableFilesFactory.create_batch(new_files_cnt, table=table, file_type=TableFiles.FILE_TYPE_NEW, date=110) + to_index_files = TableFilesFactory.create_batch(to_index_cnt, table=table, file_type=TableFiles.FILE_TYPE_TO_INDEX, date=110) + raw_files = TableFilesFactory.create_batch(raw_cnt, table=table, file_type=TableFiles.FILE_TYPE_RAW, date=120) + backup_files = TableFilesFactory.create_batch(backup_cnt, table=table, file_type=TableFiles.FILE_TYPE_BACKUP, date=110) + index_files = TableFilesFactory.create_batch(index_cnt, table=table, file_type=TableFiles.FILE_TYPE_INDEX, date=110) + new_index_files = TableFilesFactory.create_batch(new_index_cnt, table=table, file_type=TableFiles.FILE_TYPE_NEW_INDEX, date=110) + new_merge_files = TableFilesFactory.create_batch(new_merge_cnt, table=table, file_type=TableFiles.FILE_TYPE_NEW_MERGE, date=110) + to_delete_files = TableFilesFactory.create_batch(to_delete_cnt, table=table, file_type=TableFiles.FILE_TYPE_TO_DELETE, date=110) + assert table.files_to_search().count() == raw_cnt + index_cnt + to_index_cnt + + assert table.files_to_search([(100, 115)]).count() == index_cnt + to_index_cnt + assert table.files_to_search([(111, 120)]).count() == 0 + assert table.files_to_search([(111, 121)]).count() == raw_cnt + assert table.files_to_search([(110, 121)]).count() == raw_cnt + index_cnt + to_index_cnt diff --git a/shards/mishards/test_server.py b/shards/mishards/test_server.py new file mode 100644 index 0000000000..f0cde2184c --- /dev/null +++ b/shards/mishards/test_server.py @@ -0,0 +1,279 @@ +import logging +import pytest +import mock +import datetime +import random +import faker +import inspect +from milvus import Milvus +from milvus.client.types import Status, IndexType, MetricType +from milvus.client.abstract import IndexParam, TableSchema +from milvus.grpc_gen import status_pb2, milvus_pb2 +from mishards import db, create_app, settings +from mishards.service_handler import ServiceHandler +from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser +from mishards.factories import TableFilesFactory, TablesFactory, TableFiles, Tables +from mishards.router import RouterMixin + +logger = logging.getLogger(__name__) + +OK = Status(code=Status.SUCCESS, message='Success') +BAD = Status(code=Status.PERMISSION_DENIED, message='Fail') + + +@pytest.mark.usefixtures('started_app') +class TestServer: + @property + def client(self): + m = Milvus() + m.connect(host='localhost', port=settings.SERVER_TEST_PORT) + return m + + def test_server_start(self, started_app): + assert started_app.conn_mgr.metas.get('WOSERVER') == settings.WOSERVER + + def test_cmd(self, started_app): + ServiceHandler._get_server_version = mock.MagicMock(return_value=(OK, + '')) + status, _ = self.client.server_version() + assert status.OK() + + Parser.parse_proto_Command = mock.MagicMock(return_value=(BAD, 'cmd')) + status, _ = self.client.server_version() + assert not status.OK() + + def test_drop_index(self, started_app): + table_name = inspect.currentframe().f_code.co_name + ServiceHandler._drop_index = mock.MagicMock(return_value=OK) + status = self.client.drop_index(table_name) + assert status.OK() + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(BAD, table_name)) + status = self.client.drop_index(table_name) + assert not status.OK() + + def test_describe_index(self, started_app): + table_name = inspect.currentframe().f_code.co_name + index_type = IndexType.FLAT + nlist = 1 + index_param = IndexParam(table_name=table_name, + index_type=index_type, + nlist=nlist) + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(OK, table_name)) + ServiceHandler._describe_index = mock.MagicMock( + return_value=(OK, index_param)) + status, ret = self.client.describe_index(table_name) + assert status.OK() + assert ret._table_name == index_param._table_name + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(BAD, table_name)) + status, _ = self.client.describe_index(table_name) + assert not status.OK() + + def test_preload(self, started_app): + table_name = inspect.currentframe().f_code.co_name + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(OK, table_name)) + ServiceHandler._preload_table = mock.MagicMock(return_value=OK) + status = self.client.preload_table(table_name) + assert status.OK() + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(BAD, table_name)) + status = self.client.preload_table(table_name) + assert not status.OK() + + @pytest.mark.skip + def test_delete_by_range(self, started_app): + table_name = inspect.currentframe().f_code.co_name + + unpacked = table_name, datetime.datetime.today( + ), datetime.datetime.today() + + Parser.parse_proto_DeleteByRangeParam = mock.MagicMock( + return_value=(OK, unpacked)) + ServiceHandler._delete_by_range = mock.MagicMock(return_value=OK) + status = self.client.delete_vectors_by_range( + *unpacked) + assert status.OK() + + Parser.parse_proto_DeleteByRangeParam = mock.MagicMock( + return_value=(BAD, unpacked)) + status = self.client.delete_vectors_by_range( + *unpacked) + assert not status.OK() + + def test_count_table(self, started_app): + table_name = inspect.currentframe().f_code.co_name + count = random.randint(100, 200) + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(OK, table_name)) + ServiceHandler._count_table = mock.MagicMock(return_value=(OK, count)) + status, ret = self.client.get_table_row_count(table_name) + assert status.OK() + assert ret == count + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(BAD, table_name)) + status, _ = self.client.get_table_row_count(table_name) + assert not status.OK() + + def test_show_tables(self, started_app): + tables = ['t1', 't2'] + ServiceHandler._show_tables = mock.MagicMock(return_value=(OK, tables)) + status, ret = self.client.show_tables() + assert status.OK() + assert ret == tables + + def test_describe_table(self, started_app): + table_name = inspect.currentframe().f_code.co_name + dimension = 128 + nlist = 1 + table_schema = TableSchema(table_name=table_name, + index_file_size=100, + metric_type=MetricType.L2, + dimension=dimension) + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(OK, table_schema.table_name)) + ServiceHandler._describe_table = mock.MagicMock( + return_value=(OK, table_schema)) + status, _ = self.client.describe_table(table_name) + assert status.OK() + + ServiceHandler._describe_table = mock.MagicMock( + return_value=(BAD, table_schema)) + status, _ = self.client.describe_table(table_name) + assert not status.OK() + + Parser.parse_proto_TableName = mock.MagicMock(return_value=(BAD, + 'cmd')) + status, ret = self.client.describe_table(table_name) + assert not status.OK() + + def test_insert(self, started_app): + table_name = inspect.currentframe().f_code.co_name + vectors = [[random.random() for _ in range(16)] for _ in range(10)] + ids = [random.randint(1000000, 20000000) for _ in range(10)] + ServiceHandler._add_vectors = mock.MagicMock(return_value=(OK, ids)) + status, ret = self.client.add_vectors( + table_name=table_name, records=vectors) + assert status.OK() + assert ids == ret + + def test_create_index(self, started_app): + table_name = inspect.currentframe().f_code.co_name + unpacks = table_name, None + Parser.parse_proto_IndexParam = mock.MagicMock(return_value=(OK, + unpacks)) + ServiceHandler._create_index = mock.MagicMock(return_value=OK) + status = self.client.create_index(table_name=table_name) + assert status.OK() + + Parser.parse_proto_IndexParam = mock.MagicMock(return_value=(BAD, + None)) + status = self.client.create_index(table_name=table_name) + assert not status.OK() + + def test_drop_table(self, started_app): + table_name = inspect.currentframe().f_code.co_name + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(OK, table_name)) + ServiceHandler._delete_table = mock.MagicMock(return_value=OK) + status = self.client.delete_table(table_name=table_name) + assert status.OK() + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(BAD, table_name)) + status = self.client.delete_table(table_name=table_name) + assert not status.OK() + + def test_has_table(self, started_app): + table_name = inspect.currentframe().f_code.co_name + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(OK, table_name)) + ServiceHandler._has_table = mock.MagicMock(return_value=(OK, True)) + has = self.client.has_table(table_name=table_name) + assert has + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(BAD, table_name)) + status, has = self.client.has_table(table_name=table_name) + assert not status.OK() + assert not has + + def test_create_table(self, started_app): + table_name = inspect.currentframe().f_code.co_name + dimension = 128 + table_schema = dict(table_name=table_name, + index_file_size=100, + metric_type=MetricType.L2, + dimension=dimension) + + ServiceHandler._create_table = mock.MagicMock(return_value=OK) + status = self.client.create_table(table_schema) + assert status.OK() + + Parser.parse_proto_TableSchema = mock.MagicMock(return_value=(BAD, + None)) + status = self.client.create_table(table_schema) + assert not status.OK() + + def random_data(self, n, dimension): + return [[random.random() for _ in range(dimension)] for _ in range(n)] + + def test_search(self, started_app): + table_name = inspect.currentframe().f_code.co_name + to_index_cnt = random.randint(10, 20) + table = TablesFactory(table_id=table_name, state=Tables.NORMAL) + to_index_files = TableFilesFactory.create_batch( + to_index_cnt, table=table, file_type=TableFiles.FILE_TYPE_TO_INDEX) + topk = random.randint(5, 10) + nq = random.randint(5, 10) + param = { + 'table_name': table_name, + 'query_records': self.random_data(nq, table.dimension), + 'top_k': topk, + 'nprobe': 2049 + } + + result = [ + milvus_pb2.TopKQueryResult(query_result_arrays=[ + milvus_pb2.QueryResult(id=i, distance=random.random()) + for i in range(topk) + ]) for i in range(nq) + ] + + mock_results = milvus_pb2.TopKQueryResultList(status=status_pb2.Status( + error_code=status_pb2.SUCCESS, reason="Success"), + topk_query_result=result) + + table_schema = TableSchema(table_name=table_name, + index_file_size=table.index_file_size, + metric_type=table.metric_type, + dimension=table.dimension) + + status, _ = self.client.search_vectors(**param) + assert status.code == Status.ILLEGAL_ARGUMENT + + param['nprobe'] = 2048 + RouterMixin.connection = mock.MagicMock(return_value=Milvus()) + RouterMixin.query_conn = mock.MagicMock(return_value=Milvus()) + Milvus.describe_table = mock.MagicMock(return_value=(BAD, + table_schema)) + status, ret = self.client.search_vectors(**param) + assert status.code == Status.TABLE_NOT_EXISTS + + Milvus.describe_table = mock.MagicMock(return_value=(OK, table_schema)) + Milvus.search_vectors_in_files = mock.MagicMock( + return_value=mock_results) + + status, ret = self.client.search_vectors(**param) + assert status.OK() + assert len(ret) == nq diff --git a/shards/mishards/utilities.py b/shards/mishards/utilities.py new file mode 100644 index 0000000000..42e982b5f1 --- /dev/null +++ b/shards/mishards/utilities.py @@ -0,0 +1,20 @@ +import datetime +from mishards import exceptions + + +def format_date(start, end): + return ((start.year - 1900) * 10000 + (start.month - 1) * 100 + start.day, + (end.year - 1900) * 10000 + (end.month - 1) * 100 + end.day) + + +def range_to_date(range_obj, metadata=None): + try: + start = datetime.datetime.strptime(range_obj.start_date, '%Y-%m-%d') + end = datetime.datetime.strptime(range_obj.end_date, '%Y-%m-%d') + assert start < end + except (ValueError, AssertionError): + raise exceptions.InvalidRangeError('Invalid time range: {} {}'.format( + range_obj.start_date, range_obj.end_date), + metadata=metadata) + + return format_date(start, end) diff --git a/shards/requirements.txt b/shards/requirements.txt new file mode 100644 index 0000000000..14bdde2a06 --- /dev/null +++ b/shards/requirements.txt @@ -0,0 +1,37 @@ +environs==4.2.0 +factory-boy==2.12.0 +Faker==1.0.7 +fire==0.1.3 +google-auth==1.6.3 +grpcio==1.22.0 +grpcio-tools==1.22.0 +kubernetes==10.0.1 +MarkupSafe==1.1.1 +marshmallow==2.19.5 +pymysql==0.9.3 +protobuf==3.9.1 +py==1.8.0 +pyasn1==0.4.7 +pyasn1-modules==0.2.6 +pylint==2.3.1 +pymilvus-test==0.2.28 +#pymilvus==0.2.0 +pyparsing==2.4.0 +pytest==4.6.3 +pytest-level==0.1.1 +pytest-print==0.1.2 +pytest-repeat==0.8.0 +pytest-timeout==1.3.3 +python-dateutil==2.8.0 +python-dotenv==0.10.3 +pytz==2019.1 +requests==2.22.0 +requests-oauthlib==1.2.0 +rsa==4.0 +six==1.12.0 +SQLAlchemy==1.3.5 +urllib3==1.25.3 +jaeger-client>=3.4.0 +grpcio-opentracing>=1.0 +mock==2.0.0 +pluginbase==1.0.0 diff --git a/shards/setup.cfg b/shards/setup.cfg new file mode 100644 index 0000000000..4a88432914 --- /dev/null +++ b/shards/setup.cfg @@ -0,0 +1,4 @@ +[tool:pytest] +testpaths = mishards +log_cli=true +log_cli_level=info diff --git a/shards/tracer/__init__.py b/shards/tracer/__init__.py new file mode 100644 index 0000000000..64a5b50d15 --- /dev/null +++ b/shards/tracer/__init__.py @@ -0,0 +1,43 @@ +from contextlib import contextmanager + + +def empty_server_interceptor_decorator(target_server, interceptor): + return target_server + + +@contextmanager +def EmptySpan(*args, **kwargs): + yield None + return + + +class Tracer: + def __init__(self, + tracer=None, + interceptor=None, + server_decorator=empty_server_interceptor_decorator): + self.tracer = tracer + self.interceptor = interceptor + self.server_decorator = server_decorator + + def decorate(self, server): + return self.server_decorator(server, self.interceptor) + + @property + def empty(self): + return self.tracer is None + + def close(self): + self.tracer and self.tracer.close() + + def start_span(self, + operation_name=None, + child_of=None, + references=None, + tags=None, + start_time=None, + ignore_active_span=False): + if self.empty: + return EmptySpan() + return self.tracer.start_span(operation_name, child_of, references, + tags, start_time, ignore_active_span) diff --git a/shards/tracer/factory.py b/shards/tracer/factory.py new file mode 100644 index 0000000000..0e54a5aeb6 --- /dev/null +++ b/shards/tracer/factory.py @@ -0,0 +1,27 @@ +import os +import logging +from tracer import Tracer +from utils.plugins import BaseMixin + +logger = logging.getLogger(__name__) +PLUGIN_PACKAGE_NAME = 'tracer.plugins' + + +class TracerFactory(BaseMixin): + PLUGIN_TYPE = 'Tracer' + + def __init__(self, searchpath=None): + super().__init__(searchpath=searchpath, package_name=PLUGIN_PACKAGE_NAME) + + def create(self, class_name, **kwargs): + if not class_name: + return Tracer() + return super().create(class_name, **kwargs) + + def _create(self, plugin_class, **kwargs): + plugin_config = kwargs.pop('plugin_config', None) + if not plugin_config: + raise RuntimeError('\'{}\' Plugin Config is Required!'.format(self.PLUGIN_TYPE)) + + plugin = plugin_class.Create(plugin_config=plugin_config, **kwargs) + return plugin diff --git a/shards/tracer/plugins/jaeger_factory.py b/shards/tracer/plugins/jaeger_factory.py new file mode 100644 index 0000000000..923f2f805d --- /dev/null +++ b/shards/tracer/plugins/jaeger_factory.py @@ -0,0 +1,35 @@ +import logging +from jaeger_client import Config +from grpc_opentracing.grpcext import intercept_server +from grpc_opentracing import open_tracing_server_interceptor +from tracer import Tracer + +logger = logging.getLogger(__name__) + +PLUGIN_NAME = __file__ + + +class JaegerFactory: + name = 'jaeger' + @classmethod + def Create(cls, plugin_config, **kwargs): + tracing_config = plugin_config.TRACING_CONFIG + span_decorator = kwargs.pop('span_decorator', None) + service_name = plugin_config.TRACING_SERVICE_NAME + validate = plugin_config.TRACING_VALIDATE + config = Config(config=tracing_config, + service_name=service_name, + validate=validate) + + tracer = config.initialize_tracer() + tracer_interceptor = open_tracing_server_interceptor( + tracer, + log_payloads=plugin_config.TRACING_LOG_PAYLOAD, + span_decorator=span_decorator) + + return Tracer(tracer, tracer_interceptor, intercept_server) + + +def setup(app): + logger.info('Plugin \'{}\' Installed In Package: {}'.format(PLUGIN_NAME, app.plugin_package_name)) + app.on_plugin_setup(JaegerFactory) diff --git a/shards/utils/__init__.py b/shards/utils/__init__.py new file mode 100644 index 0000000000..cf444c0680 --- /dev/null +++ b/shards/utils/__init__.py @@ -0,0 +1,18 @@ +from functools import wraps + + +def singleton(cls): + instances = {} + @wraps(cls) + def getinstance(*args, **kw): + if cls not in instances: + instances[cls] = cls(*args, **kw) + return instances[cls] + return getinstance + + +class dotdict(dict): + """dot.notation access to dictionary attributes""" + __getattr__ = dict.get + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ diff --git a/shards/utils/logger_helper.py b/shards/utils/logger_helper.py new file mode 100644 index 0000000000..b4e3b9c5b6 --- /dev/null +++ b/shards/utils/logger_helper.py @@ -0,0 +1,152 @@ +import os +import datetime +from pytz import timezone +from logging import Filter +import logging.config + + +class InfoFilter(logging.Filter): + def filter(self, rec): + return rec.levelno == logging.INFO + + +class DebugFilter(logging.Filter): + def filter(self, rec): + return rec.levelno == logging.DEBUG + + +class WarnFilter(logging.Filter): + def filter(self, rec): + return rec.levelno == logging.WARN + + +class ErrorFilter(logging.Filter): + def filter(self, rec): + return rec.levelno == logging.ERROR + + +class CriticalFilter(logging.Filter): + def filter(self, rec): + return rec.levelno == logging.CRITICAL + + +COLORS = { + 'HEADER': '\033[95m', + 'INFO': '\033[92m', + 'DEBUG': '\033[94m', + 'WARNING': '\033[93m', + 'ERROR': '\033[95m', + 'CRITICAL': '\033[91m', + 'ENDC': '\033[0m', +} + + +class ColorFulFormatColMixin: + def format_col(self, message_str, level_name): + if level_name in COLORS.keys(): + message_str = COLORS.get(level_name) + message_str + COLORS.get( + 'ENDC') + return message_str + + +class ColorfulFormatter(logging.Formatter, ColorFulFormatColMixin): + def format(self, record): + message_str = super(ColorfulFormatter, self).format(record) + + return self.format_col(message_str, level_name=record.levelname) + + +def config(log_level, log_path, name, tz='UTC'): + def build_log_file(level, log_path, name, tz): + utc_now = datetime.datetime.utcnow() + utc_tz = timezone('UTC') + local_tz = timezone(tz) + tznow = utc_now.replace(tzinfo=utc_tz).astimezone(local_tz) + return '{}-{}-{}.log'.format(os.path.join(log_path, name), tznow.strftime("%m-%d-%Y-%H:%M:%S"), + level) + + if not os.path.exists(log_path): + os.makedirs(log_path) + + LOGGING = { + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': { + 'default': { + 'format': '%(asctime)s | %(levelname)s | %(name)s | %(threadName)s: %(message)s (%(filename)s:%(lineno)s)', + }, + 'colorful_console': { + 'format': '%(asctime)s | %(levelname)s | %(name)s | %(threadName)s: %(message)s (%(filename)s:%(lineno)s)', + '()': ColorfulFormatter, + }, + }, + 'filters': { + 'InfoFilter': { + '()': InfoFilter, + }, + 'DebugFilter': { + '()': DebugFilter, + }, + 'WarnFilter': { + '()': WarnFilter, + }, + 'ErrorFilter': { + '()': ErrorFilter, + }, + 'CriticalFilter': { + '()': CriticalFilter, + }, + }, + 'handlers': { + 'milvus_celery_console': { + 'class': 'logging.StreamHandler', + 'formatter': 'colorful_console', + }, + 'milvus_debug_file': { + 'level': 'DEBUG', + 'filters': ['DebugFilter'], + 'class': 'logging.handlers.RotatingFileHandler', + 'formatter': 'default', + 'filename': build_log_file('debug', log_path, name, tz) + }, + 'milvus_info_file': { + 'level': 'INFO', + 'filters': ['InfoFilter'], + 'class': 'logging.handlers.RotatingFileHandler', + 'formatter': 'default', + 'filename': build_log_file('info', log_path, name, tz) + }, + 'milvus_warn_file': { + 'level': 'WARN', + 'filters': ['WarnFilter'], + 'class': 'logging.handlers.RotatingFileHandler', + 'formatter': 'default', + 'filename': build_log_file('warn', log_path, name, tz) + }, + 'milvus_error_file': { + 'level': 'ERROR', + 'filters': ['ErrorFilter'], + 'class': 'logging.handlers.RotatingFileHandler', + 'formatter': 'default', + 'filename': build_log_file('error', log_path, name, tz) + }, + 'milvus_critical_file': { + 'level': 'CRITICAL', + 'filters': ['CriticalFilter'], + 'class': 'logging.handlers.RotatingFileHandler', + 'formatter': 'default', + 'filename': build_log_file('critical', log_path, name, tz) + }, + }, + 'loggers': { + '': { + 'handlers': ['milvus_celery_console', 'milvus_info_file', 'milvus_debug_file', 'milvus_warn_file', + 'milvus_error_file', 'milvus_critical_file'], + 'level': log_level, + 'propagate': False + }, + }, + 'propagate': False, + } + + logging.config.dictConfig(LOGGING) diff --git a/shards/utils/pluginextension.py b/shards/utils/pluginextension.py new file mode 100644 index 0000000000..68413a4e55 --- /dev/null +++ b/shards/utils/pluginextension.py @@ -0,0 +1,16 @@ +import importlib.util +from pluginbase import PluginBase, PluginSource + + +class MiPluginSource(PluginSource): + def load_plugin(self, name): + plugin = super().load_plugin(name) + spec = importlib.util.spec_from_file_location(self.base.package + '.' + name, plugin.__file__) + plugin = importlib.util.module_from_spec(spec) + spec.loader.exec_module(plugin) + return plugin + + +class MiPluginBase(PluginBase): + def make_plugin_source(self, *args, **kwargs): + return MiPluginSource(self, *args, **kwargs) diff --git a/shards/utils/plugins/__init__.py b/shards/utils/plugins/__init__.py new file mode 100644 index 0000000000..633f1164a7 --- /dev/null +++ b/shards/utils/plugins/__init__.py @@ -0,0 +1,40 @@ +import os +import inspect +from functools import partial +from utils.pluginextension import MiPluginBase as PluginBase + + +class BaseMixin(object): + + def __init__(self, package_name, searchpath=None): + self.plugin_package_name = package_name + caller_path = os.path.dirname(inspect.stack()[1][1]) + get_path = partial(os.path.join, caller_path) + plugin_base = PluginBase(package=self.plugin_package_name, + searchpath=[get_path('./plugins')]) + self.class_map = {} + searchpath = searchpath if searchpath else [] + searchpath = [searchpath] if isinstance(searchpath, str) else searchpath + self.source = plugin_base.make_plugin_source(searchpath=searchpath, + identifier=self.__class__.__name__) + + for plugin_name in self.source.list_plugins(): + plugin = self.source.load_plugin(plugin_name) + plugin.setup(self) + + def on_plugin_setup(self, plugin_class): + name = getattr(plugin_class, 'name', plugin_class.__name__) + self.class_map[name.lower()] = plugin_class + + def plugin(self, name): + return self.class_map.get(name, None) + + def create(self, class_name, **kwargs): + if not class_name: + raise RuntimeError('Please specify \'{}\' class_name first!'.format(self.PLUGIN_TYPE)) + + plugin_class = self.plugin(class_name.lower()) + if not plugin_class: + raise RuntimeError('{} Plugin \'{}\' Not Installed!'.format(self.PLUGIN_TYPE, class_name)) + + return self._create(plugin_class, **kwargs)