From 9a4c732563323cd8814a11a5eda8891745e264ba Mon Sep 17 00:00:00 2001 From: "peng.xu" Date: Mon, 21 Oct 2019 16:20:29 +0800 Subject: [PATCH] fix bug in test_server --- Dockerfile | 10 - build.sh | 39 -- conftest.py | 27 -- manager.py | 28 -- mishards/.env.example | 33 -- mishards/__init__.py | 36 -- mishards/connections.py | 154 -------- mishards/db_base.py | 52 --- mishards/exception_codes.py | 10 - mishards/exception_handlers.py | 82 ---- mishards/exceptions.py | 38 -- mishards/factories.py | 54 --- mishards/grpc_utils/__init__.py | 37 -- mishards/grpc_utils/grpc_args_parser.py | 102 ----- mishards/grpc_utils/grpc_args_wrapper.py | 4 - mishards/grpc_utils/test_grpc.py | 75 ---- mishards/hash_ring.py | 150 ------- mishards/main.py | 15 - mishards/models.py | 76 ---- mishards/routings.py | 96 ----- mishards/server.py | 122 ------ mishards/service_handler.py | 475 ----------------------- mishards/settings.py | 94 ----- mishards/test_connections.py | 101 ----- mishards/test_models.py | 39 -- mishards/test_server.py | 279 ------------- mishards/utilities.py | 20 - requirements.txt | 36 -- sd/__init__.py | 28 -- sd/kubernetes_provider.py | 331 ---------------- sd/static_provider.py | 39 -- setup.cfg | 4 - start_services.yml | 45 --- tracing/__init__.py | 43 -- tracing/factory.py | 40 -- utils/__init__.py | 11 - utils/logger_helper.py | 152 -------- 37 files changed, 2977 deletions(-) delete mode 100644 Dockerfile delete mode 100755 build.sh delete mode 100644 conftest.py delete mode 100644 manager.py delete mode 100644 mishards/.env.example delete mode 100644 mishards/__init__.py delete mode 100644 mishards/connections.py delete mode 100644 mishards/db_base.py delete mode 100644 mishards/exception_codes.py delete mode 100644 mishards/exception_handlers.py delete mode 100644 mishards/exceptions.py delete mode 100644 mishards/factories.py delete mode 100644 mishards/grpc_utils/__init__.py delete mode 100644 mishards/grpc_utils/grpc_args_parser.py delete mode 100644 mishards/grpc_utils/grpc_args_wrapper.py delete mode 100644 mishards/grpc_utils/test_grpc.py delete mode 100644 mishards/hash_ring.py delete mode 100644 mishards/main.py delete mode 100644 mishards/models.py delete mode 100644 mishards/routings.py delete mode 100644 mishards/server.py delete mode 100644 mishards/service_handler.py delete mode 100644 mishards/settings.py delete mode 100644 mishards/test_connections.py delete mode 100644 mishards/test_models.py delete mode 100644 mishards/test_server.py delete mode 100644 mishards/utilities.py delete mode 100644 requirements.txt delete mode 100644 sd/__init__.py delete mode 100644 sd/kubernetes_provider.py delete mode 100644 sd/static_provider.py delete mode 100644 setup.cfg delete mode 100644 start_services.yml delete mode 100644 tracing/__init__.py delete mode 100644 tracing/factory.py delete mode 100644 utils/__init__.py delete mode 100644 utils/logger_helper.py diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 594640619e..0000000000 --- a/Dockerfile +++ /dev/null @@ -1,10 +0,0 @@ -FROM python:3.6 -RUN apt update && apt install -y \ - less \ - telnet -RUN mkdir /source -WORKDIR /source -ADD ./requirements.txt ./ -RUN pip install -r requirements.txt -COPY . . 
-CMD python mishards/main.py diff --git a/build.sh b/build.sh deleted file mode 100755 index fad30518f2..0000000000 --- a/build.sh +++ /dev/null @@ -1,39 +0,0 @@ -#!/bin/bash - -BOLD=`tput bold` -NORMAL=`tput sgr0` -YELLOW='\033[1;33m' -ENDC='\033[0m' - -echo -e "${BOLD}MISHARDS_REGISTRY=${MISHARDS_REGISTRY}${ENDC}" - -function build_image() { - dockerfile=$1 - remote_registry=$2 - tagged=$2 - buildcmd="docker build -t ${tagged} -f ${dockerfile} ." - echo -e "${BOLD}$buildcmd${NORMAL}" - $buildcmd - pushcmd="docker push ${remote_registry}" - echo -e "${BOLD}$pushcmd${NORMAL}" - $pushcmd - echo -e "${YELLOW}${BOLD}Image: ${remote_registry}${NORMAL}${ENDC}" -} - -case "$1" in - -all) - [[ -z $MISHARDS_REGISTRY ]] && { - echo -e "${YELLOW}Error: Please set docker registry first:${ENDC}\n\t${BOLD}export MISHARDS_REGISTRY=xxxx\n${ENDC}" - exit 1 - } - - version="" - [[ ! -z $2 ]] && version=":${2}" - build_image "Dockerfile" "${MISHARDS_REGISTRY}${version}" "${MISHARDS_REGISTRY}" - ;; -*) - echo "Usage: [option...] {base | apps}" - echo "all, Usage: build.sh all [tagname|] => {docker_registry}:\${tagname}" - ;; -esac diff --git a/conftest.py b/conftest.py deleted file mode 100644 index 34e22af693..0000000000 --- a/conftest.py +++ /dev/null @@ -1,27 +0,0 @@ -import logging -import pytest -import grpc -from mishards import settings, db, create_app - -logger = logging.getLogger(__name__) - - -@pytest.fixture -def app(request): - app = create_app(settings.TestingConfig) - db.drop_all() - db.create_all() - - yield app - - db.drop_all() - - -@pytest.fixture -def started_app(app): - app.on_pre_run() - app.start(settings.SERVER_TEST_PORT) - - yield app - - app.stop() diff --git a/manager.py b/manager.py deleted file mode 100644 index 931c90ebc8..0000000000 --- a/manager.py +++ /dev/null @@ -1,28 +0,0 @@ -import fire -from mishards import db -from sqlalchemy import and_ - - -class DBHandler: - @classmethod - def create_all(cls): - db.create_all() - - @classmethod - def drop_all(cls): - db.drop_all() - - @classmethod - def fun(cls, tid): - from mishards.factories import TablesFactory, TableFilesFactory, Tables - f = db.Session.query(Tables).filter(and_( - Tables.table_id == tid, - Tables.state != Tables.TO_DELETE) - ).first() - print(f) - - # f1 = TableFilesFactory() - - -if __name__ == '__main__': - fire.Fire(DBHandler) diff --git a/mishards/.env.example b/mishards/.env.example deleted file mode 100644 index 0a23c0cf56..0000000000 --- a/mishards/.env.example +++ /dev/null @@ -1,33 +0,0 @@ -DEBUG=True - -WOSERVER=tcp://127.0.0.1:19530 -SERVER_PORT=19532 -SERVER_TEST_PORT=19888 - -SD_PROVIDER=Static - -SD_NAMESPACE=xp -SD_IN_CLUSTER=False -SD_POLL_INTERVAL=5 -SD_ROSERVER_POD_PATT=.*-ro-servers-.* -SD_LABEL_SELECTOR=tier=ro-servers - -SD_STATIC_HOSTS=127.0.0.1 -SD_STATIC_PORT=19530 - -#SQLALCHEMY_DATABASE_URI=mysql+pymysql://root:root@127.0.0.1:3306/milvus?charset=utf8mb4 -SQLALCHEMY_DATABASE_URI=sqlite:////tmp/milvus/db/meta.sqlite?check_same_thread=False -SQL_ECHO=True - -#SQLALCHEMY_DATABASE_TEST_URI=mysql+pymysql://root:root@127.0.0.1:3306/milvus?charset=utf8mb4 -SQLALCHEMY_DATABASE_TEST_URI=sqlite:////tmp/milvus/db/meta.sqlite?check_same_thread=False -SQL_TEST_ECHO=False - -# TRACING_TEST_TYPE=jaeger -TRACING_TYPE=jaeger -TRACING_SERVICE_NAME=fortest -TRACING_SAMPLER_TYPE=const -TRACING_SAMPLER_PARAM=1 -TRACING_LOG_PAYLOAD=True -#TRACING_SAMPLER_TYPE=probabilistic -#TRACING_SAMPLER_PARAM=0.5 diff --git a/mishards/__init__.py b/mishards/__init__.py deleted file mode 100644 index 7db3d8cb5e..0000000000 
--- a/mishards/__init__.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import logging
-from mishards import settings
-logger = logging.getLogger()
-
-from mishards.db_base import DB
-db = DB()
-
-from mishards.server import Server
-grpc_server = Server()
-
-
-def create_app(testing_config=None):
-    config = testing_config if testing_config else settings.DefaultConfig
-    db.init_db(uri=config.SQLALCHEMY_DATABASE_URI, echo=config.SQL_ECHO)
-
-    from mishards.connections import ConnectionMgr
-    connect_mgr = ConnectionMgr()
-
-    from sd import ProviderManager
-
-    sd_provider_class = ProviderManager.get_provider(settings.SD_PROVIDER)
-    discover = sd_provider_class(settings=settings.SD_PROVIDER_SETTINGS, conn_mgr=connect_mgr)
-
-    from tracing.factory import TracerFactory
-    from mishards.grpc_utils import GrpcSpanDecorator
-    tracer = TracerFactory.new_tracer(config.TRACING_TYPE, settings.TracingConfig,
-                                      span_decorator=GrpcSpanDecorator())
-
-    from mishards.routings import RouterFactory
-    router = RouterFactory.new_router(config.ROUTER_CLASS_NAME, connect_mgr)
-
-    grpc_server.init_app(conn_mgr=connect_mgr, tracer=tracer, router=router, discover=discover)
-
-    from mishards import exception_handlers
-
-    return grpc_server
diff --git a/mishards/connections.py b/mishards/connections.py
deleted file mode 100644
index 618690a099..0000000000
--- a/mishards/connections.py
+++ /dev/null
@@ -1,154 +0,0 @@
-import logging
-import threading
-from functools import wraps
-from milvus import Milvus
-
-from mishards import (settings, exceptions)
-from utils import singleton
-
-logger = logging.getLogger(__name__)
-
-
-class Connection:
-    def __init__(self, name, uri, max_retry=1, error_handlers=None, **kwargs):
-        self.name = name
-        self.uri = uri
-        self.max_retry = max_retry
-        self.retried = 0
-        self.conn = Milvus()
-        self.error_handlers = [] if not error_handlers else error_handlers
-        self.on_retry_func = kwargs.get('on_retry_func', None)
-        # self._connect()
-
-    def __str__(self):
-        return 'Connection:name=\"{}\";uri=\"{}\"'.format(self.name, self.uri)
-
-    def _connect(self, metadata=None):
-        try:
-            self.conn.connect(uri=self.uri)
-        except Exception as e:
-            if not self.error_handlers:
-                raise exceptions.ConnectionConnectError(message=str(e), metadata=metadata)
-            for handler in self.error_handlers:
-                handler(e, metadata=metadata)
-
-    @property
-    def can_retry(self):
-        return self.retried < self.max_retry
-
-    @property
-    def connected(self):
-        return self.conn.connected()
-
-    def on_retry(self):
-        if self.on_retry_func:
-            self.on_retry_func(self)
-        else:
-            self.retried > 1 and logger.warning('{} is retrying {}'.format(self, self.retried))
-
-    def on_connect(self, metadata=None):
-        while not self.connected and self.can_retry:
-            self.retried += 1
-            self.on_retry()
-            self._connect(metadata=metadata)
-
-        if not self.can_retry and not self.connected:
-            raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry),
-                                                    metadata=metadata)
-
-        self.retried = 0
-
-    def connect(self, func, exception_handler=None):
-        @wraps(func)
-        def inner(*args, **kwargs):
-            self.on_connect()
-            try:
-                return func(*args, **kwargs)
-            except Exception as e:
-                if exception_handler:
-                    exception_handler(e)
-                else:
-                    raise e
-        return inner
-
-
-@singleton
-class ConnectionMgr:
-    def __init__(self):
-        self.metas = {}
-        self.conns = {}
-
-    @property
-    def conn_names(self):
-        return set(self.metas.keys()) - set(['WOSERVER'])
-
-    def conn(self, name, metadata, throw=False):
-        c = self.conns.get(name, None)
-        if not c:
-            url = self.metas.get(name, None)
-            if not url:
-                if not throw:
-                    return None
-                raise exceptions.ConnectionNotFoundError(message='Connection {} not found'.format(name),
-                                                         metadata=metadata)
-            this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY)
-            threaded = {
-                threading.get_ident(): this_conn
-            }
-            self.conns[name] = threaded
-            return this_conn
-
-        tid = threading.get_ident()
-        rconn = c.get(tid, None)
-        if not rconn:
-            url = self.metas.get(name, None)
-            if not url:
-                if not throw:
-                    return None
-                raise exceptions.ConnectionNotFoundError('Connection {} not found'.format(name),
-                                                         metadata=metadata)
-            this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY)
-            c[tid] = this_conn
-            return this_conn
-
-        return rconn
-
-    def on_new_meta(self, name, url):
-        logger.info('Register Connection: name={};url={}'.format(name, url))
-        self.metas[name] = url
-
-    def on_duplicate_meta(self, name, url):
-        if self.metas[name] == url:
-            return self.on_same_meta(name, url)
-
-        return self.on_diff_meta(name, url)
-
-    def on_same_meta(self, name, url):
-        # logger.warning('Register same meta: {}:{}'.format(name, url))
-        pass
-
-    def on_diff_meta(self, name, url):
-        logger.warning('Received {} with diff url={}'.format(name, url))
-        self.metas[name] = url
-        self.conns[name] = {}
-
-    def on_unregister_meta(self, name, url):
-        logger.info('Unregister name={};url={}'.format(name, url))
-        self.conns.pop(name, None)
-
-    def on_nonexisted_meta(self, name):
-        logger.warning('Non-existed meta: {}'.format(name))
-
-    def register(self, name, url):
-        meta = self.metas.get(name)
-        if not meta:
-            return self.on_new_meta(name, url)
-        else:
-            return self.on_duplicate_meta(name, url)
-
-    def unregister(self, name):
-        logger.info('Unregister Connection: name={}'.format(name))
-        url = self.metas.pop(name, None)
-        if url is None:
-            return self.on_nonexisted_meta(name)
-        return self.on_unregister_meta(name, url)
diff --git a/mishards/db_base.py b/mishards/db_base.py
deleted file mode 100644
index 5f2eee9ba1..0000000000
--- a/mishards/db_base.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import logging
-from sqlalchemy import create_engine
-from sqlalchemy.engine.url import make_url
-from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.orm import sessionmaker, scoped_session
-from sqlalchemy.orm.session import Session as SessionBase
-
-logger = logging.getLogger(__name__)
-
-
-class LocalSession(SessionBase):
-    def __init__(self, db, autocommit=False, autoflush=True, **options):
-        self.db = db
-        bind = options.pop('bind', None) or db.engine
-        SessionBase.__init__(self, autocommit=autocommit, autoflush=autoflush, bind=bind, **options)
-
-
-class DB:
-    Model = declarative_base()
-
-    def __init__(self, uri=None, echo=False):
-        self.echo = echo
-        uri and self.init_db(uri, echo)
-        self.session_factory = scoped_session(sessionmaker(class_=LocalSession, db=self))
-
-    def init_db(self, uri, echo=False):
-        url = make_url(uri)
-        if url.get_backend_name() == 'sqlite':
-            self.engine = create_engine(url)
-        else:
-            self.engine = create_engine(uri, pool_size=100, pool_recycle=5, pool_timeout=30,
-                                        pool_pre_ping=True,
-                                        echo=echo,
-                                        max_overflow=0)
-        self.uri = uri
-        self.url = url
-
-    def __str__(self):
-        return '<DB: backend={}; database={}>'.format(self.url.get_backend_name(), self.url.database)
-
-    @property
-    def Session(self):
-        return self.session_factory()
-
-    def remove_session(self):
-        self.session_factory.remove()
-
-    def drop_all(self):
-        self.Model.metadata.drop_all(self.engine)
-
-    def create_all(self):
self.Model.metadata.create_all(self.engine) diff --git a/mishards/exception_codes.py b/mishards/exception_codes.py deleted file mode 100644 index bdd4572dd5..0000000000 --- a/mishards/exception_codes.py +++ /dev/null @@ -1,10 +0,0 @@ -INVALID_CODE = -1 - -CONNECT_ERROR_CODE = 10001 -CONNECTTION_NOT_FOUND_CODE = 10002 -DB_ERROR_CODE = 10003 - -TABLE_NOT_FOUND_CODE = 20001 -INVALID_ARGUMENT_CODE = 20002 -INVALID_DATE_RANGE_CODE = 20003 -INVALID_TOPK_CODE = 20004 diff --git a/mishards/exception_handlers.py b/mishards/exception_handlers.py deleted file mode 100644 index c79a6db5a3..0000000000 --- a/mishards/exception_handlers.py +++ /dev/null @@ -1,82 +0,0 @@ -import logging -from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2 -from mishards import grpc_server as server, exceptions - -logger = logging.getLogger(__name__) - - -def resp_handler(err, error_code): - if not isinstance(err, exceptions.BaseException): - return status_pb2.Status(error_code=error_code, reason=str(err)) - - status = status_pb2.Status(error_code=error_code, reason=err.message) - - if err.metadata is None: - return status - - resp_class = err.metadata.get('resp_class', None) - if not resp_class: - return status - - if resp_class == milvus_pb2.BoolReply: - return resp_class(status=status, bool_reply=False) - - if resp_class == milvus_pb2.VectorIds: - return resp_class(status=status, vector_id_array=[]) - - if resp_class == milvus_pb2.TopKQueryResultList: - return resp_class(status=status, topk_query_result=[]) - - if resp_class == milvus_pb2.TableRowCount: - return resp_class(status=status, table_row_count=-1) - - if resp_class == milvus_pb2.TableName: - return resp_class(status=status, table_name=[]) - - if resp_class == milvus_pb2.StringReply: - return resp_class(status=status, string_reply='') - - if resp_class == milvus_pb2.TableSchema: - return milvus_pb2.TableSchema( - status=status - ) - - if resp_class == milvus_pb2.IndexParam: - return milvus_pb2.IndexParam( - table_name=milvus_pb2.TableName( - status=status - ) - ) - - status.error_code = status_pb2.UNEXPECTED_ERROR - return status - - -@server.errorhandler(exceptions.TableNotFoundError) -def TableNotFoundErrorHandler(err): - logger.error(err) - return resp_handler(err, status_pb2.TABLE_NOT_EXISTS) - - -@server.errorhandler(exceptions.InvalidTopKError) -def InvalidTopKErrorHandler(err): - logger.error(err) - return resp_handler(err, status_pb2.ILLEGAL_TOPK) - - -@server.errorhandler(exceptions.InvalidArgumentError) -def InvalidArgumentErrorHandler(err): - logger.error(err) - return resp_handler(err, status_pb2.ILLEGAL_ARGUMENT) - - -@server.errorhandler(exceptions.DBError) -def DBErrorHandler(err): - logger.error(err) - return resp_handler(err, status_pb2.UNEXPECTED_ERROR) - - -@server.errorhandler(exceptions.InvalidRangeError) -def InvalidArgumentErrorHandler(err): - logger.error(err) - return resp_handler(err, status_pb2.ILLEGAL_RANGE) diff --git a/mishards/exceptions.py b/mishards/exceptions.py deleted file mode 100644 index 72839f88d2..0000000000 --- a/mishards/exceptions.py +++ /dev/null @@ -1,38 +0,0 @@ -import mishards.exception_codes as codes - - -class BaseException(Exception): - code = codes.INVALID_CODE - message = 'BaseException' - - def __init__(self, message='', metadata=None): - self.message = self.__class__.__name__ if not message else message - self.metadata = metadata - - -class ConnectionConnectError(BaseException): - code = codes.CONNECT_ERROR_CODE - - -class ConnectionNotFoundError(BaseException): - code = 
codes.CONNECTTION_NOT_FOUND_CODE - - -class DBError(BaseException): - code = codes.DB_ERROR_CODE - - -class TableNotFoundError(BaseException): - code = codes.TABLE_NOT_FOUND_CODE - - -class InvalidTopKError(BaseException): - code = codes.INVALID_TOPK_CODE - - -class InvalidArgumentError(BaseException): - code = codes.INVALID_ARGUMENT_CODE - - -class InvalidRangeError(BaseException): - code = codes.INVALID_DATE_RANGE_CODE diff --git a/mishards/factories.py b/mishards/factories.py deleted file mode 100644 index 52c0253b39..0000000000 --- a/mishards/factories.py +++ /dev/null @@ -1,54 +0,0 @@ -import time -import datetime -import random -import factory -from factory.alchemy import SQLAlchemyModelFactory -from faker import Faker -from faker.providers import BaseProvider - -from milvus.client.types import MetricType -from mishards import db -from mishards.models import Tables, TableFiles - - -class FakerProvider(BaseProvider): - def this_date(self): - t = datetime.datetime.today() - return (t.year - 1900) * 10000 + (t.month - 1) * 100 + t.day - - -factory.Faker.add_provider(FakerProvider) - - -class TablesFactory(SQLAlchemyModelFactory): - class Meta: - model = Tables - sqlalchemy_session = db.session_factory - sqlalchemy_session_persistence = 'commit' - - id = factory.Faker('random_number', digits=16, fix_len=True) - table_id = factory.Faker('uuid4') - state = factory.Faker('random_element', elements=(0, 1)) - dimension = factory.Faker('random_element', elements=(256, 512)) - created_on = int(time.time()) - index_file_size = 0 - engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3)) - metric_type = factory.Faker('random_element', elements=(MetricType.L2, MetricType.IP)) - nlist = 16384 - - -class TableFilesFactory(SQLAlchemyModelFactory): - class Meta: - model = TableFiles - sqlalchemy_session = db.session_factory - sqlalchemy_session_persistence = 'commit' - - id = factory.Faker('random_number', digits=16, fix_len=True) - table = factory.SubFactory(TablesFactory) - engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3)) - file_id = factory.Faker('uuid4') - file_type = factory.Faker('random_element', elements=(0, 1, 2, 3, 4)) - file_size = factory.Faker('random_number') - updated_time = int(time.time()) - created_on = int(time.time()) - date = factory.Faker('this_date') diff --git a/mishards/grpc_utils/__init__.py b/mishards/grpc_utils/__init__.py deleted file mode 100644 index f5225b2a66..0000000000 --- a/mishards/grpc_utils/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -from grpc_opentracing import SpanDecorator -from milvus.grpc_gen import status_pb2 - - -class GrpcSpanDecorator(SpanDecorator): - def __call__(self, span, rpc_info): - status = None - if not rpc_info.response: - return - if isinstance(rpc_info.response, status_pb2.Status): - status = rpc_info.response - else: - try: - status = rpc_info.response.status - except Exception as e: - status = status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, - reason='Should not happen') - - if status.error_code == 0: - return - error_log = {'event': 'error', - 'request': rpc_info.request, - 'response': rpc_info.response - } - span.set_tag('error', True) - span.log_kv(error_log) - - -def mark_grpc_method(func): - setattr(func, 'grpc_method', True) - return func - - -def is_grpc_method(func): - if not func: - return False - return getattr(func, 'grpc_method', False) diff --git a/mishards/grpc_utils/grpc_args_parser.py b/mishards/grpc_utils/grpc_args_parser.py deleted file mode 100644 index 039299803d..0000000000 --- 
a/mishards/grpc_utils/grpc_args_parser.py +++ /dev/null @@ -1,102 +0,0 @@ -from milvus import Status -from functools import wraps - - -def error_status(func): - @wraps(func) - def inner(*args, **kwargs): - try: - results = func(*args, **kwargs) - except Exception as e: - return Status(code=Status.UNEXPECTED_ERROR, message=str(e)), None - - return Status(code=0, message="Success"), results - - return inner - - -class GrpcArgsParser(object): - - @classmethod - @error_status - def parse_proto_TableSchema(cls, param): - _table_schema = { - 'status': param.status, - 'table_name': param.table_name, - 'dimension': param.dimension, - 'index_file_size': param.index_file_size, - 'metric_type': param.metric_type - } - - return _table_schema - - @classmethod - @error_status - def parse_proto_TableName(cls, param): - return param.table_name - - @classmethod - @error_status - def parse_proto_Index(cls, param): - _index = { - 'index_type': param.index_type, - 'nlist': param.nlist - } - - return _index - - @classmethod - @error_status - def parse_proto_IndexParam(cls, param): - _table_name = param.table_name - _status, _index = cls.parse_proto_Index(param.index) - - if not _status.OK(): - raise Exception("Argument parse error") - - return _table_name, _index - - @classmethod - @error_status - def parse_proto_Command(cls, param): - _cmd = param.cmd - - return _cmd - - @classmethod - @error_status - def parse_proto_Range(cls, param): - _start_value = param.start_value - _end_value = param.end_value - - return _start_value, _end_value - - @classmethod - @error_status - def parse_proto_RowRecord(cls, param): - return list(param.vector_data) - - @classmethod - @error_status - def parse_proto_SearchParam(cls, param): - _table_name = param.table_name - _topk = param.topk - _nprobe = param.nprobe - _status, _range = cls.parse_proto_Range(param.query_range_array) - - if not _status.OK(): - raise Exception("Argument parse error") - - _row_record = param.query_record_array - - return _table_name, _row_record, _range, _topk - - @classmethod - @error_status - def parse_proto_DeleteByRangeParam(cls, param): - _table_name = param.table_name - _range = param.range - _start_value = _range.start_value - _end_value = _range.end_value - - return _table_name, _start_value, _end_value diff --git a/mishards/grpc_utils/grpc_args_wrapper.py b/mishards/grpc_utils/grpc_args_wrapper.py deleted file mode 100644 index 7447dbd995..0000000000 --- a/mishards/grpc_utils/grpc_args_wrapper.py +++ /dev/null @@ -1,4 +0,0 @@ -# class GrpcArgsWrapper(object): - -# @classmethod -# def proto_TableName(cls): diff --git a/mishards/grpc_utils/test_grpc.py b/mishards/grpc_utils/test_grpc.py deleted file mode 100644 index 9af09e5d0d..0000000000 --- a/mishards/grpc_utils/test_grpc.py +++ /dev/null @@ -1,75 +0,0 @@ -import logging -import opentracing -from mishards.grpc_utils import GrpcSpanDecorator, is_grpc_method -from milvus.grpc_gen import status_pb2, milvus_pb2 - -logger = logging.getLogger(__name__) - - -class FakeTracer(opentracing.Tracer): - pass - - -class FakeSpan(opentracing.Span): - def __init__(self, context, tracer, **kwargs): - super(FakeSpan, self).__init__(tracer, context) - self.reset() - - def set_tag(self, key, value): - self.tags.append({key: value}) - - def log_kv(self, key_values, timestamp=None): - self.logs.append(key_values) - - def reset(self): - self.tags = [] - self.logs = [] - - -class FakeRpcInfo: - def __init__(self, request, response): - self.request = request - self.response = response - - -class TestGrpcUtils: - def 
test_span_deco(self):
-        request = 'request'
-        OK = status_pb2.Status(error_code=status_pb2.SUCCESS, reason='Success')
-        response = OK
-        rpc_info = FakeRpcInfo(request=request, response=response)
-        span = FakeSpan(context=None, tracer=FakeTracer())
-        span_deco = GrpcSpanDecorator()
-        span_deco(span, rpc_info)
-        assert len(span.logs) == 0
-        assert len(span.tags) == 0
-
-        response = milvus_pb2.BoolReply(status=OK, bool_reply=False)
-        rpc_info = FakeRpcInfo(request=request, response=response)
-        span = FakeSpan(context=None, tracer=FakeTracer())
-        span_deco = GrpcSpanDecorator()
-        span_deco(span, rpc_info)
-        assert len(span.logs) == 0
-        assert len(span.tags) == 0
-
-        response = 1
-        rpc_info = FakeRpcInfo(request=request, response=response)
-        span = FakeSpan(context=None, tracer=FakeTracer())
-        span_deco = GrpcSpanDecorator()
-        span_deco(span, rpc_info)
-        assert len(span.logs) == 1
-        assert len(span.tags) == 1
-
-        response = 0
-        rpc_info = FakeRpcInfo(request=request, response=response)
-        span = FakeSpan(context=None, tracer=FakeTracer())
-        span_deco = GrpcSpanDecorator()
-        span_deco(span, rpc_info)
-        assert len(span.logs) == 0
-        assert len(span.tags) == 0
-
-    def test_is_grpc_method(self):
-        target = 1
-        assert not is_grpc_method(target)
-        target = None
-        assert not is_grpc_method(target)
diff --git a/mishards/hash_ring.py b/mishards/hash_ring.py
deleted file mode 100644
index a97f3f580e..0000000000
--- a/mishards/hash_ring.py
+++ /dev/null
@@ -1,150 +0,0 @@
-import math
-import sys
-from bisect import bisect
-
-if sys.version_info >= (2, 5):
-    import hashlib
-    md5_constructor = hashlib.md5
-else:
-    import md5
-    md5_constructor = md5.new
-
-
-class HashRing(object):
-    def __init__(self, nodes=None, weights=None):
-        """`nodes` is a list of objects that have a proper __str__ representation.
-        `weights` is a dictionary that sets weights for the nodes. By default
-        all nodes have equal weight.
-        """
-        self.ring = dict()
-        self._sorted_keys = []
-
-        self.nodes = nodes
-
-        if not weights:
-            weights = {}
-        self.weights = weights
-
-        self._generate_circle()
-
-    def _generate_circle(self):
-        """Generates the circle.
-        """
-        total_weight = 0
-        for node in self.nodes:
-            total_weight += self.weights.get(node, 1)
-
-        for node in self.nodes:
-            weight = 1
-
-            if node in self.weights:
-                weight = self.weights.get(node)
-
-            factor = math.floor((40 * len(self.nodes) * weight) / total_weight)
-
-            for j in range(0, int(factor)):
-                b_key = self._hash_digest('%s-%s' % (node, j))
-
-                for i in range(0, 3):
-                    key = self._hash_val(b_key, lambda x: x + i * 4)
-                    self.ring[key] = node
-                    self._sorted_keys.append(key)
-
-        self._sorted_keys.sort()
-
-    def get_node(self, string_key):
-        """Given a string key, a corresponding node in the hash ring is returned.
-
-        If the hash ring is empty, `None` is returned.
-        """
-        pos = self.get_node_pos(string_key)
-        if pos is None:
-            return None
-        return self.ring[self._sorted_keys[pos]]
-
-    def get_node_pos(self, string_key):
-        """Given a string key, the position in the ring of the node that
-        handles the key is returned.
-
-        If the hash ring is empty, `None` is returned.
-        """
-        if not self.ring:
-            return None
-
-        key = self.gen_key(string_key)
-
-        nodes = self._sorted_keys
-        pos = bisect(nodes, key)
-
-        if pos == len(nodes):
-            return 0
-        else:
-            return pos
-
-    def iterate_nodes(self, string_key, distinct=True):
-        """Given a string key it returns the nodes as a generator that can hold the key.
- - The generator iterates one time through the ring - starting at the correct position. - - if `distinct` is set, then the nodes returned will be unique, - i.e. no virtual copies will be returned. - """ - if not self.ring: - yield None, None - - returned_values = set() - - def distinct_filter(value): - if str(value) not in returned_values: - returned_values.add(str(value)) - return value - - pos = self.get_node_pos(string_key) - for key in self._sorted_keys[pos:]: - val = distinct_filter(self.ring[key]) - if val: - yield val - - for i, key in enumerate(self._sorted_keys): - if i < pos: - val = distinct_filter(self.ring[key]) - if val: - yield val - - def gen_key(self, key): - """Given a string key it returns a long value, - this long value represents a place on the hash ring. - - md5 is currently used because it mixes well. - """ - b_key = self._hash_digest(key) - return self._hash_val(b_key, lambda x: x) - - def _hash_val(self, b_key, entry_fn): - return (b_key[entry_fn(3)] << 24) | (b_key[entry_fn(2)] << 16) | ( - b_key[entry_fn(1)] << 8) | b_key[entry_fn(0)] - - def _hash_digest(self, key): - m = md5_constructor() - key = key.encode() - m.update(key) - return m.digest() - - -if __name__ == '__main__': - from collections import defaultdict - servers = [ - '192.168.0.246:11212', '192.168.0.247:11212', '192.168.0.248:11212', - '192.168.0.249:11212' - ] - - ring = HashRing(servers) - keys = ['{}'.format(i) for i in range(100)] - mapped = defaultdict(list) - for k in keys: - server = ring.get_node(k) - mapped[server].append(k) - - for k, v in mapped.items(): - print(k, v) diff --git a/mishards/main.py b/mishards/main.py deleted file mode 100644 index c0d142607b..0000000000 --- a/mishards/main.py +++ /dev/null @@ -1,15 +0,0 @@ -import os -import sys -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from mishards import (settings, create_app) - - -def main(): - server = create_app(settings.DefaultConfig) - server.run(port=settings.SERVER_PORT) - return 0 - - -if __name__ == '__main__': - sys.exit(main()) diff --git a/mishards/models.py b/mishards/models.py deleted file mode 100644 index 4b6c8f9ef4..0000000000 --- a/mishards/models.py +++ /dev/null @@ -1,76 +0,0 @@ -import logging -from sqlalchemy import (Integer, Boolean, Text, - String, BigInteger, and_, or_, - Column) -from sqlalchemy.orm import relationship, backref - -from mishards import db - -logger = logging.getLogger(__name__) - - -class TableFiles(db.Model): - FILE_TYPE_NEW = 0 - FILE_TYPE_RAW = 1 - FILE_TYPE_TO_INDEX = 2 - FILE_TYPE_INDEX = 3 - FILE_TYPE_TO_DELETE = 4 - FILE_TYPE_NEW_MERGE = 5 - FILE_TYPE_NEW_INDEX = 6 - FILE_TYPE_BACKUP = 7 - - __tablename__ = 'TableFiles' - - id = Column(BigInteger, primary_key=True, autoincrement=True) - table_id = Column(String(50)) - engine_type = Column(Integer) - file_id = Column(String(50)) - file_type = Column(Integer) - file_size = Column(Integer, default=0) - row_count = Column(Integer, default=0) - updated_time = Column(BigInteger) - created_on = Column(BigInteger) - date = Column(Integer) - - table = relationship( - 'Tables', - primaryjoin='and_(foreign(TableFiles.table_id) == Tables.table_id)', - backref=backref('files', uselist=True, lazy='dynamic') - ) - - -class Tables(db.Model): - TO_DELETE = 1 - NORMAL = 0 - - __tablename__ = 'Tables' - - id = Column(BigInteger, primary_key=True, autoincrement=True) - table_id = Column(String(50), unique=True) - state = Column(Integer) - dimension = Column(Integer) - created_on = Column(Integer) - flag = 
Column(Integer, default=0)
-    index_file_size = Column(Integer)
-    engine_type = Column(Integer)
-    nlist = Column(Integer)
-    metric_type = Column(Integer)
-
-    def files_to_search(self, date_range=None):
-        cond = or_(
-            TableFiles.file_type == TableFiles.FILE_TYPE_RAW,
-            TableFiles.file_type == TableFiles.FILE_TYPE_TO_INDEX,
-            TableFiles.file_type == TableFiles.FILE_TYPE_INDEX,
-        )
-        if date_range:
-            cond = and_(
-                cond,
-                or_(*[
-                    and_(TableFiles.date >= d[0], TableFiles.date < d[1]) for d in date_range
-                ])
-            )
-
-        files = self.files.filter(cond)
-
-        logger.debug('DATE_RANGE: {}'.format(date_range))
-        return files
diff --git a/mishards/routings.py b/mishards/routings.py
deleted file mode 100644
index 823972726f..0000000000
--- a/mishards/routings.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import logging
-from sqlalchemy import exc as sqlalchemy_exc
-from sqlalchemy import and_
-
-from mishards import exceptions, db
-from mishards.hash_ring import HashRing
-from mishards.models import Tables
-
-logger = logging.getLogger(__name__)
-
-
-class RouteManager:
-    ROUTER_CLASSES = {}
-
-    @classmethod
-    def register_router_class(cls, target):
-        name = target.__dict__.get('NAME', None)
-        name = name if name else target.__name__
-        cls.ROUTER_CLASSES[name] = target
-        return target
-
-    @classmethod
-    def get_router_class(cls, name):
-        return cls.ROUTER_CLASSES.get(name, None)
-
-
-class RouterFactory:
-    @classmethod
-    def new_router(cls, name, conn_mgr, **kwargs):
-        router_class = RouteManager.get_router_class(name)
-        assert router_class
-        return router_class(conn_mgr, **kwargs)
-
-
-class RouterMixin:
-    def __init__(self, conn_mgr):
-        self.conn_mgr = conn_mgr
-
-    def routing(self, table_name, metadata=None, **kwargs):
-        raise NotImplementedError()
-
-    def connection(self, metadata=None):
-        conn = self.conn_mgr.conn('WOSERVER', metadata=metadata)
-        if conn:
-            conn.on_connect(metadata=metadata)
-        return conn.conn
-
-    def query_conn(self, name, metadata=None):
-        conn = self.conn_mgr.conn(name, metadata=metadata)
-        if not conn:
-            raise exceptions.ConnectionNotFoundError(name, metadata=metadata)
-        conn.on_connect(metadata=metadata)
-        return conn.conn
-
-
-@RouteManager.register_router_class
-class FileBasedHashRingRouter(RouterMixin):
-    NAME = 'FileBasedHashRingRouter'
-
-    def __init__(self, conn_mgr, **kwargs):
-        super(FileBasedHashRingRouter, self).__init__(conn_mgr)
-
-    def routing(self, table_name, metadata=None, **kwargs):
-        range_array = kwargs.pop('range_array', None)
-        return self._route(table_name, range_array, metadata, **kwargs)
-
-    def _route(self, table_name, range_array, metadata=None, **kwargs):
-        # PXU TODO: Implement Thread-local Context
-        # PXU TODO: Session life mgt
-        try:
-            table = db.Session.query(Tables).filter(
-                and_(Tables.table_id == table_name,
-                     Tables.state != Tables.TO_DELETE)).first()
-        except sqlalchemy_exc.SQLAlchemyError as e:
-            raise exceptions.DBError(message=str(e), metadata=metadata)
-
-        if not table:
-            raise exceptions.TableNotFoundError(table_name, metadata=metadata)
-        files = table.files_to_search(range_array)
-        db.remove_session()
-
-        servers = self.conn_mgr.conn_names
-        logger.info('Available servers: {}'.format(servers))
-
-        ring = HashRing(servers)
-
-        routing = {}
-
-        for f in files:
-            target_host = ring.get_node(str(f.id))
-            sub = routing.get(target_host, None)
-            if not sub:
-                routing[target_host] = {'table_id': table_name, 'file_ids': []}
-            routing[target_host]['file_ids'].append(str(f.id))
-
-        return routing
diff --git a/mishards/server.py b/mishards/server.py
deleted file mode 100644
index 599a00e455..0000000000
--- a/mishards/server.py
+++ /dev/null
@@ -1,122 +0,0 @@
-import logging
-import grpc
-import time
-import socket
-import inspect
-from urllib.parse import urlparse
-from functools import wraps
-from concurrent import futures
-from grpc._cython import cygrpc
-from milvus.grpc_gen.milvus_pb2_grpc import add_MilvusServiceServicer_to_server
-from mishards.grpc_utils import is_grpc_method
-from mishards.service_handler import ServiceHandler
-from mishards import settings
-
-logger = logging.getLogger(__name__)
-
-
-class Server:
-    def __init__(self):
-        self.pre_run_handlers = set()
-        self.grpc_methods = set()
-        self.error_handlers = {}
-        self.exit_flag = False
-
-    def init_app(self,
-                 conn_mgr,
-                 tracer,
-                 router,
-                 discover,
-                 port=19530,
-                 max_workers=10,
-                 **kwargs):
-        self.port = int(port)
-        self.conn_mgr = conn_mgr
-        self.tracer = tracer
-        self.router = router
-        self.discover = discover
-
-        self.server_impl = grpc.server(
-            thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers),
-            options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
-                     (cygrpc.ChannelArgKey.max_receive_message_length, -1)])
-
-        self.server_impl = self.tracer.decorate(self.server_impl)
-
-        self.register_pre_run_handler(self.pre_run_handler)
-
-    def pre_run_handler(self):
-        woserver = settings.WOSERVER
-        url = urlparse(woserver)
-        ip = socket.gethostbyname(url.hostname)
-        socket.inet_pton(socket.AF_INET, ip)
-        self.conn_mgr.register(
-            'WOSERVER', '{}://{}:{}'.format(url.scheme, ip, url.port or 80))
-
-    def register_pre_run_handler(self, func):
-        logger.info('Registering {} into server pre_run_handlers'.format(func))
-        self.pre_run_handlers.add(func)
-        return func
-
-    def wrap_method_with_errorhandler(self, func):
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            try:
-                return func(*args, **kwargs)
-            except Exception as e:
-                if e.__class__ in self.error_handlers:
-                    return self.error_handlers[e.__class__](e)
-                raise
-
-        return wrapper
-
-    def errorhandler(self, exception):
-        if inspect.isclass(exception) and issubclass(exception, Exception):
-
-            def wrapper(func):
-                self.error_handlers[exception] = func
-                return func
-
-            return wrapper
-        return exception
-
-    def on_pre_run(self):
-        for handler in self.pre_run_handlers:
-            handler()
-        self.discover.start()
-
-    def start(self, port=None):
-        handler_class = self.decorate_handler(ServiceHandler)
-        add_MilvusServiceServicer_to_server(
-            handler_class(tracer=self.tracer,
-                          router=self.router), self.server_impl)
-        self.server_impl.add_insecure_port("[::]:{}".format(
-            str(port or self.port)))
-        self.server_impl.start()
-
-    def run(self, port):
-        logger.info('Milvus server start ......')
-        port = port or self.port
-        self.on_pre_run()
-
-        self.start(port)
-        logger.info('Listening on port {}'.format(port))
-
-        try:
-            while not self.exit_flag:
-                time.sleep(5)
-        except KeyboardInterrupt:
-            self.stop()
-
-    def stop(self):
-        logger.info('Server is shutting down ......')
-        self.exit_flag = True
-        self.server_impl.stop(0)
-        self.tracer.close()
-        logger.info('Server is closed')
-
-    def decorate_handler(self, handler):
-        for key, attr in handler.__dict__.items():
-            if is_grpc_method(attr):
-                setattr(handler, key, self.wrap_method_with_errorhandler(attr))
-        return handler
diff --git a/mishards/service_handler.py b/mishards/service_handler.py
deleted file mode 100644
index 5e91c14f14..0000000000
--- a/mishards/service_handler.py
+++ /dev/null
@@ -1,475 +0,0 @@
-import logging
-import time
-import datetime
-from collections
import defaultdict - -import multiprocessing -from concurrent.futures import ThreadPoolExecutor -from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2 -from milvus.grpc_gen.milvus_pb2 import TopKQueryResult -from milvus.client.abstract import Range -from milvus.client import types as Types - -from mishards import (db, settings, exceptions) -from mishards.grpc_utils import mark_grpc_method -from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser -from mishards import utilities - -logger = logging.getLogger(__name__) - - -class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): - MAX_NPROBE = 2048 - MAX_TOPK = 2048 - - def __init__(self, tracer, router, max_workers=multiprocessing.cpu_count(), **kwargs): - self.table_meta = {} - self.error_handlers = {} - self.tracer = tracer - self.router = router - self.max_workers = max_workers - - def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs): - status = status_pb2.Status(error_code=status_pb2.SUCCESS, - reason="Success") - if not files_n_topk_results: - return status, [] - - request_results = defaultdict(list) - - calc_time = time.time() - for files_collection in files_n_topk_results: - if isinstance(files_collection, tuple): - status, _ = files_collection - return status, [] - for request_pos, each_request_results in enumerate( - files_collection.topk_query_result): - request_results[request_pos].extend( - each_request_results.query_result_arrays) - request_results[request_pos] = sorted( - request_results[request_pos], - key=lambda x: x.distance, - reverse=reverse)[:topk] - - calc_time = time.time() - calc_time - logger.info('Merge takes {}'.format(calc_time)) - - results = sorted(request_results.items()) - topk_query_result = [] - - for result in results: - query_result = TopKQueryResult(query_result_arrays=result[1]) - topk_query_result.append(query_result) - - return status, topk_query_result - - def _do_query(self, - context, - table_id, - table_meta, - vectors, - topk, - nprobe, - range_array=None, - **kwargs): - metadata = kwargs.get('metadata', None) - range_array = [ - utilities.range_to_date(r, metadata=metadata) for r in range_array - ] if range_array else None - - routing = {} - p_span = None if self.tracer.empty else context.get_active_span( - ).context - with self.tracer.start_span('get_routing', child_of=p_span): - routing = self.router.routing(table_id, - range_array=range_array, - metadata=metadata) - logger.info('Routing: {}'.format(routing)) - - metadata = kwargs.get('metadata', None) - - rs = [] - all_topk_results = [] - - def search(addr, query_params, vectors, topk, nprobe, **kwargs): - logger.info( - 'Send Search Request: addr={};params={};nq={};topk={};nprobe={}' - .format(addr, query_params, len(vectors), topk, nprobe)) - - conn = self.router.query_conn(addr, metadata=metadata) - start = time.time() - span = kwargs.get('span', None) - span = span if span else (None if self.tracer.empty else - context.get_active_span().context) - - with self.tracer.start_span('search_{}'.format(addr), - child_of=span): - ret = conn.search_vectors_in_files( - table_name=query_params['table_id'], - file_ids=query_params['file_ids'], - query_records=vectors, - top_k=topk, - nprobe=nprobe, - lazy_=True) - end = time.time() - logger.info('search_vectors_in_files takes: {}'.format(end - start)) - - all_topk_results.append(ret) - - with self.tracer.start_span('do_search', child_of=p_span) as span: - with ThreadPoolExecutor(max_workers=self.max_workers) as pool: - for addr, params in 
routing.items(): - res = pool.submit(search, - addr, - params, - vectors, - topk, - nprobe, - span=span) - rs.append(res) - - for res in rs: - res.result() - - reverse = table_meta.metric_type == Types.MetricType.IP - with self.tracer.start_span('do_merge', child_of=p_span): - return self._do_merge(all_topk_results, - topk, - reverse=reverse, - metadata=metadata) - - def _create_table(self, table_schema): - return self.router.connection().create_table(table_schema) - - @mark_grpc_method - def CreateTable(self, request, context): - _status, _table_schema = Parser.parse_proto_TableSchema(request) - - if not _status.OK(): - return status_pb2.Status(error_code=_status.code, - reason=_status.message) - - logger.info('CreateTable {}'.format(_table_schema['table_name'])) - - _status = self._create_table(_table_schema) - - return status_pb2.Status(error_code=_status.code, - reason=_status.message) - - def _has_table(self, table_name, metadata=None): - return self.router.connection(metadata=metadata).has_table(table_name) - - @mark_grpc_method - def HasTable(self, request, context): - _status, _table_name = Parser.parse_proto_TableName(request) - - if not _status.OK(): - return milvus_pb2.BoolReply(status=status_pb2.Status( - error_code=_status.code, reason=_status.message), - bool_reply=False) - - logger.info('HasTable {}'.format(_table_name)) - - _status, _bool = self._has_table(_table_name, - metadata={'resp_class': milvus_pb2.BoolReply}) - - return milvus_pb2.BoolReply(status=status_pb2.Status( - error_code=_status.code, reason=_status.message), - bool_reply=_bool) - - def _delete_table(self, table_name): - return self.router.connection().delete_table(table_name) - - @mark_grpc_method - def DropTable(self, request, context): - _status, _table_name = Parser.parse_proto_TableName(request) - - if not _status.OK(): - return status_pb2.Status(error_code=_status.code, - reason=_status.message) - - logger.info('DropTable {}'.format(_table_name)) - - _status = self._delete_table(_table_name) - - return status_pb2.Status(error_code=_status.code, - reason=_status.message) - - def _create_index(self, table_name, index): - return self.router.connection().create_index(table_name, index) - - @mark_grpc_method - def CreateIndex(self, request, context): - _status, unpacks = Parser.parse_proto_IndexParam(request) - - if not _status.OK(): - return status_pb2.Status(error_code=_status.code, - reason=_status.message) - - _table_name, _index = unpacks - - logger.info('CreateIndex {}'.format(_table_name)) - - # TODO: interface create_table incompleted - _status = self._create_index(_table_name, _index) - - return status_pb2.Status(error_code=_status.code, - reason=_status.message) - - def _add_vectors(self, param, metadata=None): - return self.router.connection(metadata=metadata).add_vectors( - None, None, insert_param=param) - - @mark_grpc_method - def Insert(self, request, context): - logger.info('Insert') - # TODO: Ths SDK interface add_vectors() could update, add a key 'row_id_array' - _status, _ids = self._add_vectors( - metadata={'resp_class': milvus_pb2.VectorIds}, param=request) - return milvus_pb2.VectorIds(status=status_pb2.Status( - error_code=_status.code, reason=_status.message), - vector_id_array=_ids) - - @mark_grpc_method - def Search(self, request, context): - - table_name = request.table_name - - topk = request.topk - nprobe = request.nprobe - - logger.info('Search {}: topk={} nprobe={}'.format( - table_name, topk, nprobe)) - - metadata = {'resp_class': milvus_pb2.TopKQueryResultList} - - if nprobe 
> self.MAX_NPROBE or nprobe <= 0:
-            raise exceptions.InvalidArgumentError(
-                message='Invalid nprobe: {}'.format(nprobe), metadata=metadata)
-
-        if topk > self.MAX_TOPK or topk <= 0:
-            raise exceptions.InvalidTopKError(
-                message='Invalid topk: {}'.format(topk), metadata=metadata)
-
-        table_meta = self.table_meta.get(table_name, None)
-
-        if not table_meta:
-            status, info = self.router.connection(
-                metadata=metadata).describe_table(table_name)
-            if not status.OK():
-                raise exceptions.TableNotFoundError(table_name,
-                                                    metadata=metadata)
-
-            self.table_meta[table_name] = info
-            table_meta = info
-
-        start = time.time()
-
-        query_record_array = []
-
-        for query_record in request.query_record_array:
-            query_record_array.append(list(query_record.vector_data))
-
-        query_range_array = []
-        for query_range in request.query_range_array:
-            query_range_array.append(
-                Range(query_range.start_value, query_range.end_value))
-
-        status, results = self._do_query(context,
-                                         table_name,
-                                         table_meta,
-                                         query_record_array,
-                                         topk,
-                                         nprobe,
-                                         query_range_array,
-                                         metadata=metadata)
-
-        now = time.time()
-        logger.info('SearchVector takes: {}'.format(now - start))
-
-        topk_result_list = milvus_pb2.TopKQueryResultList(
-            status=status_pb2.Status(error_code=status.error_code,
-                                     reason=status.reason),
-            topk_query_result=results)
-        return topk_result_list
-
-    @mark_grpc_method
-    def SearchInFiles(self, request, context):
-        raise NotImplementedError()
-
-    def _describe_table(self, table_name, metadata=None):
-        return self.router.connection(metadata=metadata).describe_table(table_name)
-
-    @mark_grpc_method
-    def DescribeTable(self, request, context):
-        _status, _table_name = Parser.parse_proto_TableName(request)
-
-        if not _status.OK():
-            return milvus_pb2.TableSchema(status=status_pb2.Status(
-                error_code=_status.code, reason=_status.message), )
-
-        metadata = {'resp_class': milvus_pb2.TableSchema}
-
-        logger.info('DescribeTable {}'.format(_table_name))
-        _status, _table = self._describe_table(metadata=metadata,
-                                               table_name=_table_name)
-
-        if _status.OK():
-            return milvus_pb2.TableSchema(
-                table_name=_table_name,
-                index_file_size=_table.index_file_size,
-                dimension=_table.dimension,
-                metric_type=_table.metric_type,
-                status=status_pb2.Status(error_code=_status.code,
-                                         reason=_status.message),
-            )
-
-        return milvus_pb2.TableSchema(
-            table_name=_table_name,
-            status=status_pb2.Status(error_code=_status.code,
-                                     reason=_status.message),
-        )
-
-    def _count_table(self, table_name, metadata=None):
-        return self.router.connection(
-            metadata=metadata).get_table_row_count(table_name)
-
-    @mark_grpc_method
-    def CountTable(self, request, context):
-        _status, _table_name = Parser.parse_proto_TableName(request)
-
-        if not _status.OK():
-            status = status_pb2.Status(error_code=_status.code,
-                                       reason=_status.message)
-
-            return milvus_pb2.TableRowCount(status=status)
-
-        logger.info('CountTable {}'.format(_table_name))
-
-        metadata = {'resp_class': milvus_pb2.TableRowCount}
-        _status, _count = self._count_table(_table_name, metadata=metadata)
-
-        return milvus_pb2.TableRowCount(
-            status=status_pb2.Status(error_code=_status.code,
-                                     reason=_status.message),
-            table_row_count=_count if isinstance(_count, int) else -1)
-
-    def _get_server_version(self, metadata=None):
-        return self.router.connection(metadata=metadata).server_version()
-
-    @mark_grpc_method
-    def Cmd(self, request, context):
-        _status, _cmd = Parser.parse_proto_Command(request)
-        logger.info('Cmd: {}'.format(_cmd))
-
-        if not _status.OK():
-            return
milvus_pb2.StringReply(status=status_pb2.Status( - error_code=_status.code, reason=_status.message)) - - metadata = {'resp_class': milvus_pb2.StringReply} - - if _cmd == 'version': - _status, _reply = self._get_server_version(metadata=metadata) - else: - _status, _reply = self.router.connection( - metadata=metadata).server_status() - - return milvus_pb2.StringReply(status=status_pb2.Status( - error_code=_status.code, reason=_status.message), - string_reply=_reply) - - def _show_tables(self, metadata=None): - return self.router.connection(metadata=metadata).show_tables() - - @mark_grpc_method - def ShowTables(self, request, context): - logger.info('ShowTables') - metadata = {'resp_class': milvus_pb2.TableName} - _status, _results = self._show_tables(metadata=metadata) - - return milvus_pb2.TableNameList(status=status_pb2.Status( - error_code=_status.code, reason=_status.message), - table_names=_results) - - def _delete_by_range(self, table_name, start_date, end_date): - return self.router.connection().delete_vectors_by_range(table_name, - start_date, - end_date) - - @mark_grpc_method - def DeleteByRange(self, request, context): - _status, unpacks = \ - Parser.parse_proto_DeleteByRangeParam(request) - - if not _status.OK(): - return status_pb2.Status(error_code=_status.code, - reason=_status.message) - - _table_name, _start_date, _end_date = unpacks - - logger.info('DeleteByRange {}: {} {}'.format(_table_name, _start_date, - _end_date)) - _status = self._delete_by_range(_table_name, _start_date, _end_date) - return status_pb2.Status(error_code=_status.code, - reason=_status.message) - - def _preload_table(self, table_name): - return self.router.connection().preload_table(table_name) - - @mark_grpc_method - def PreloadTable(self, request, context): - _status, _table_name = Parser.parse_proto_TableName(request) - - if not _status.OK(): - return status_pb2.Status(error_code=_status.code, - reason=_status.message) - - logger.info('PreloadTable {}'.format(_table_name)) - _status = self._preload_table(_table_name) - return status_pb2.Status(error_code=_status.code, - reason=_status.message) - - def _describe_index(self, table_name, metadata=None): - return self.router.connection(metadata=metadata).describe_index(table_name) - - @mark_grpc_method - def DescribeIndex(self, request, context): - _status, _table_name = Parser.parse_proto_TableName(request) - - if not _status.OK(): - return milvus_pb2.IndexParam(status=status_pb2.Status( - error_code=_status.code, reason=_status.message)) - - metadata = {'resp_class': milvus_pb2.IndexParam} - - logger.info('DescribeIndex {}'.format(_table_name)) - _status, _index_param = self._describe_index(table_name=_table_name, - metadata=metadata) - - if not _index_param: - return milvus_pb2.IndexParam(status=status_pb2.Status( - error_code=_status.code, reason=_status.message)) - - _index = milvus_pb2.Index(index_type=_index_param._index_type, - nlist=_index_param._nlist) - - return milvus_pb2.IndexParam(status=status_pb2.Status( - error_code=_status.code, reason=_status.message), - table_name=_table_name, - index=_index) - - def _drop_index(self, table_name): - return self.router.connection().drop_index(table_name) - - @mark_grpc_method - def DropIndex(self, request, context): - _status, _table_name = Parser.parse_proto_TableName(request) - - if not _status.OK(): - return status_pb2.Status(error_code=_status.code, - reason=_status.message) - - logger.info('DropIndex {}'.format(_table_name)) - _status = self._drop_index(_table_name) - return 
status_pb2.Status(error_code=_status.code, - reason=_status.message) diff --git a/mishards/settings.py b/mishards/settings.py deleted file mode 100644 index 21a3bb7a65..0000000000 --- a/mishards/settings.py +++ /dev/null @@ -1,94 +0,0 @@ -import sys -import os - -from environs import Env -env = Env() - -FROM_EXAMPLE = env.bool('FROM_EXAMPLE', False) -if FROM_EXAMPLE: - from dotenv import load_dotenv - load_dotenv('./mishards/.env.example') -else: - env.read_env() - -DEBUG = env.bool('DEBUG', False) - -LOG_LEVEL = env.str('LOG_LEVEL', 'DEBUG' if DEBUG else 'INFO') -LOG_PATH = env.str('LOG_PATH', '/tmp/mishards') -LOG_NAME = env.str('LOG_NAME', 'logfile') -TIMEZONE = env.str('TIMEZONE', 'UTC') - -from utils.logger_helper import config -config(LOG_LEVEL, LOG_PATH, LOG_NAME, TIMEZONE) - -TIMEOUT = env.int('TIMEOUT', 60) -MAX_RETRY = env.int('MAX_RETRY', 3) - -SERVER_PORT = env.int('SERVER_PORT', 19530) -SERVER_TEST_PORT = env.int('SERVER_TEST_PORT', 19530) -WOSERVER = env.str('WOSERVER') - -SD_PROVIDER_SETTINGS = None -SD_PROVIDER = env.str('SD_PROVIDER', 'Kubernetes') -if SD_PROVIDER == 'Kubernetes': - from sd.kubernetes_provider import KubernetesProviderSettings - SD_PROVIDER_SETTINGS = KubernetesProviderSettings( - namespace=env.str('SD_NAMESPACE', ''), - in_cluster=env.bool('SD_IN_CLUSTER', False), - poll_interval=env.int('SD_POLL_INTERVAL', 5), - pod_patt=env.str('SD_ROSERVER_POD_PATT', ''), - label_selector=env.str('SD_LABEL_SELECTOR', ''), - port=env.int('SD_PORT', 19530)) -elif SD_PROVIDER == 'Static': - from sd.static_provider import StaticProviderSettings - SD_PROVIDER_SETTINGS = StaticProviderSettings( - hosts=env.list('SD_STATIC_HOSTS', []), - port=env.int('SD_STATIC_PORT', 19530)) - -# TESTING_WOSERVER = env.str('TESTING_WOSERVER', 'tcp://127.0.0.1:19530') - - -class TracingConfig: - TRACING_SERVICE_NAME = env.str('TRACING_SERVICE_NAME', 'mishards') - TRACING_VALIDATE = env.bool('TRACING_VALIDATE', True) - TRACING_LOG_PAYLOAD = env.bool('TRACING_LOG_PAYLOAD', False) - TRACING_CONFIG = { - 'sampler': { - 'type': env.str('TRACING_SAMPLER_TYPE', 'const'), - 'param': env.str('TRACING_SAMPLER_PARAM', "1"), - }, - 'local_agent': { - 'reporting_host': env.str('TRACING_REPORTING_HOST', '127.0.0.1'), - 'reporting_port': env.str('TRACING_REPORTING_PORT', '5775') - }, - 'logging': env.bool('TRACING_LOGGING', True) - } - DEFAULT_TRACING_CONFIG = { - 'sampler': { - 'type': env.str('TRACING_SAMPLER_TYPE', 'const'), - 'param': env.str('TRACING_SAMPLER_PARAM', "0"), - } - } - - -class DefaultConfig: - SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_URI') - SQL_ECHO = env.bool('SQL_ECHO', False) - TRACING_TYPE = env.str('TRACING_TYPE', '') - ROUTER_CLASS_NAME = env.str('ROUTER_CLASS_NAME', 'FileBasedHashRingRouter') - - -class TestingConfig(DefaultConfig): - SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_TEST_URI', '') - SQL_ECHO = env.bool('SQL_TEST_ECHO', False) - TRACING_TYPE = env.str('TRACING_TEST_TYPE', '') - ROUTER_CLASS_NAME = env.str('ROUTER_CLASS_TEST_NAME', 'FileBasedHashRingRouter') - - -if __name__ == '__main__': - import logging - logger = logging.getLogger(__name__) - logger.debug('DEBUG') - logger.info('INFO') - logger.warn('WARN') - logger.error('ERROR') diff --git a/mishards/test_connections.py b/mishards/test_connections.py deleted file mode 100644 index 819d2e03da..0000000000 --- a/mishards/test_connections.py +++ /dev/null @@ -1,101 +0,0 @@ -import logging -import pytest -import mock - -from milvus import Milvus -from mishards.connections import (ConnectionMgr, 
Connection) -from mishards import exceptions - -logger = logging.getLogger(__name__) - - -@pytest.mark.usefixtures('app') -class TestConnection: - def test_manager(self): - mgr = ConnectionMgr() - - mgr.register('pod1', '111') - mgr.register('pod2', '222') - mgr.register('pod2', '222') - mgr.register('pod2', '2222') - assert len(mgr.conn_names) == 2 - - mgr.unregister('pod1') - assert len(mgr.conn_names) == 1 - - mgr.unregister('pod2') - assert len(mgr.conn_names) == 0 - - mgr.register('WOSERVER', 'xxxx') - assert len(mgr.conn_names) == 0 - - assert not mgr.conn('XXXX', None) - with pytest.raises(exceptions.ConnectionNotFoundError): - mgr.conn('XXXX', None, True) - - mgr.conn('WOSERVER', None) - - def test_connection(self): - class Conn: - def __init__(self, state): - self.state = state - - def connect(self, uri): - return self.state - - def connected(self): - return self.state - - FAIL_CONN = Conn(False) - PASS_CONN = Conn(True) - - class Retry: - def __init__(self): - self.times = 0 - - def __call__(self, conn): - self.times += 1 - logger.info('Retrying {}'.format(self.times)) - - class Func(): - def __init__(self): - self.executed = False - - def __call__(self): - self.executed = True - - max_retry = 3 - - RetryObj = Retry() - - c = Connection('client', - uri='xx', - max_retry=max_retry, - on_retry_func=RetryObj) - c.conn = FAIL_CONN - ff = Func() - this_connect = c.connect(func=ff) - with pytest.raises(exceptions.ConnectionConnectError): - this_connect() - assert RetryObj.times == max_retry - assert not ff.executed - RetryObj = Retry() - - c.conn = PASS_CONN - this_connect = c.connect(func=ff) - this_connect() - assert ff.executed - assert RetryObj.times == 0 - - this_connect = c.connect(func=None) - with pytest.raises(TypeError): - this_connect() - - errors = [] - - def error_handler(err): - errors.append(err) - - this_connect = c.connect(func=None, exception_handler=error_handler) - this_connect() - assert len(errors) == 1 diff --git a/mishards/test_models.py b/mishards/test_models.py deleted file mode 100644 index d60b62713e..0000000000 --- a/mishards/test_models.py +++ /dev/null @@ -1,39 +0,0 @@ -import logging -import pytest -from mishards.factories import TableFiles, Tables, TableFilesFactory, TablesFactory -from mishards import db, create_app, settings -from mishards.factories import ( - Tables, TableFiles, - TablesFactory, TableFilesFactory -) - -logger = logging.getLogger(__name__) - - -@pytest.mark.usefixtures('app') -class TestModels: - def test_files_to_search(self): - table = TablesFactory() - new_files_cnt = 5 - to_index_cnt = 10 - raw_cnt = 20 - backup_cnt = 12 - to_delete_cnt = 9 - index_cnt = 8 - new_index_cnt = 6 - new_merge_cnt = 11 - - new_files = TableFilesFactory.create_batch(new_files_cnt, table=table, file_type=TableFiles.FILE_TYPE_NEW, date=110) - to_index_files = TableFilesFactory.create_batch(to_index_cnt, table=table, file_type=TableFiles.FILE_TYPE_TO_INDEX, date=110) - raw_files = TableFilesFactory.create_batch(raw_cnt, table=table, file_type=TableFiles.FILE_TYPE_RAW, date=120) - backup_files = TableFilesFactory.create_batch(backup_cnt, table=table, file_type=TableFiles.FILE_TYPE_BACKUP, date=110) - index_files = TableFilesFactory.create_batch(index_cnt, table=table, file_type=TableFiles.FILE_TYPE_INDEX, date=110) - new_index_files = TableFilesFactory.create_batch(new_index_cnt, table=table, file_type=TableFiles.FILE_TYPE_NEW_INDEX, date=110) - new_merge_files = TableFilesFactory.create_batch(new_merge_cnt, table=table, 
file_type=TableFiles.FILE_TYPE_NEW_MERGE, date=110) - to_delete_files = TableFilesFactory.create_batch(to_delete_cnt, table=table, file_type=TableFiles.FILE_TYPE_TO_DELETE, date=110) - assert table.files_to_search().count() == raw_cnt + index_cnt + to_index_cnt - - assert table.files_to_search([(100, 115)]).count() == index_cnt + to_index_cnt - assert table.files_to_search([(111, 120)]).count() == 0 - assert table.files_to_search([(111, 121)]).count() == raw_cnt - assert table.files_to_search([(110, 121)]).count() == raw_cnt + index_cnt + to_index_cnt diff --git a/mishards/test_server.py b/mishards/test_server.py deleted file mode 100644 index efd3912076..0000000000 --- a/mishards/test_server.py +++ /dev/null @@ -1,279 +0,0 @@ -import logging -import pytest -import mock -import datetime -import random -import faker -import inspect -from milvus import Milvus -from milvus.client.types import Status, IndexType, MetricType -from milvus.client.abstract import IndexParam, TableSchema -from milvus.grpc_gen import status_pb2, milvus_pb2 -from mishards import db, create_app, settings -from mishards.service_handler import ServiceHandler -from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser -from mishards.factories import TableFilesFactory, TablesFactory, TableFiles, Tables -from mishards.routings import RouterMixin - -logger = logging.getLogger(__name__) - -OK = Status(code=Status.SUCCESS, message='Success') -BAD = Status(code=Status.PERMISSION_DENIED, message='Fail') - - -@pytest.mark.usefixtures('started_app') -class TestServer: - @property - def client(self): - m = Milvus() - m.connect(host='localhost', port=settings.SERVER_TEST_PORT) - return m - - def test_server_start(self, started_app): - assert started_app.conn_mgr.metas.get('WOSERVER') == settings.WOSERVER - - def test_cmd(self, started_app): - ServiceHandler._get_server_version = mock.MagicMock(return_value=(OK, - '')) - status, _ = self.client.server_version() - assert status.OK() - - Parser.parse_proto_Command = mock.MagicMock(return_value=(BAD, 'cmd')) - status, _ = self.client.server_version() - assert not status.OK() - - def test_drop_index(self, started_app): - table_name = inspect.currentframe().f_code.co_name - ServiceHandler._drop_index = mock.MagicMock(return_value=OK) - status = self.client.drop_index(table_name) - assert status.OK() - - Parser.parse_proto_TableName = mock.MagicMock( - return_value=(BAD, table_name)) - status = self.client.drop_index(table_name) - assert not status.OK() - - def test_describe_index(self, started_app): - table_name = inspect.currentframe().f_code.co_name - index_type = IndexType.FLAT - nlist = 1 - index_param = IndexParam(table_name=table_name, - index_type=index_type, - nlist=nlist) - Parser.parse_proto_TableName = mock.MagicMock( - return_value=(OK, table_name)) - ServiceHandler._describe_index = mock.MagicMock( - return_value=(OK, index_param)) - status, ret = self.client.describe_index(table_name) - assert status.OK() - assert ret._table_name == index_param._table_name - - Parser.parse_proto_TableName = mock.MagicMock( - return_value=(BAD, table_name)) - status, _ = self.client.describe_index(table_name) - assert not status.OK() - - def test_preload(self, started_app): - table_name = inspect.currentframe().f_code.co_name - - Parser.parse_proto_TableName = mock.MagicMock( - return_value=(OK, table_name)) - ServiceHandler._preload_table = mock.MagicMock(return_value=OK) - status = self.client.preload_table(table_name) - assert status.OK() - - Parser.parse_proto_TableName = 
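One idiom worth noting before the remaining tests: each test derives its table name from the currently running function via inspect, so names never collide across tests. The idiom in isolation (some_test is an arbitrary name):

import inspect

def some_test():
    # f_code.co_name is the name of the function owning the current frame
    return inspect.currentframe().f_code.co_name

assert some_test() == 'some_test'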
mock.MagicMock( - return_value=(BAD, table_name)) - status = self.client.preload_table(table_name) - assert not status.OK() - - @pytest.mark.skip - def test_delete_by_range(self, started_app): - table_name = inspect.currentframe().f_code.co_name - - unpacked = table_name, datetime.datetime.today( - ), datetime.datetime.today() - - Parser.parse_proto_DeleteByRangeParam = mock.MagicMock( - return_value=(OK, unpacked)) - ServiceHandler._delete_by_range = mock.MagicMock(return_value=OK) - status = self.client.delete_vectors_by_range( - *unpacked) - assert status.OK() - - Parser.parse_proto_DeleteByRangeParam = mock.MagicMock( - return_value=(BAD, unpacked)) - status = self.client.delete_vectors_by_range( - *unpacked) - assert not status.OK() - - def test_count_table(self, started_app): - table_name = inspect.currentframe().f_code.co_name - count = random.randint(100, 200) - - Parser.parse_proto_TableName = mock.MagicMock( - return_value=(OK, table_name)) - ServiceHandler._count_table = mock.MagicMock(return_value=(OK, count)) - status, ret = self.client.get_table_row_count(table_name) - assert status.OK() - assert ret == count - - Parser.parse_proto_TableName = mock.MagicMock( - return_value=(BAD, table_name)) - status, _ = self.client.get_table_row_count(table_name) - assert not status.OK() - - def test_show_tables(self, started_app): - tables = ['t1', 't2'] - ServiceHandler._show_tables = mock.MagicMock(return_value=(OK, tables)) - status, ret = self.client.show_tables() - assert status.OK() - assert ret == tables - - def test_describe_table(self, started_app): - table_name = inspect.currentframe().f_code.co_name - dimension = 128 - nlist = 1 - table_schema = TableSchema(table_name=table_name, - index_file_size=100, - metric_type=MetricType.L2, - dimension=dimension) - Parser.parse_proto_TableName = mock.MagicMock( - return_value=(OK, table_schema.table_name)) - ServiceHandler._describe_table = mock.MagicMock( - return_value=(OK, table_schema)) - status, _ = self.client.describe_table(table_name) - assert status.OK() - - ServiceHandler._describe_table = mock.MagicMock( - return_value=(BAD, table_schema)) - status, _ = self.client.describe_table(table_name) - assert not status.OK() - - Parser.parse_proto_TableName = mock.MagicMock(return_value=(BAD, - 'cmd')) - status, ret = self.client.describe_table(table_name) - assert not status.OK() - - def test_insert(self, started_app): - table_name = inspect.currentframe().f_code.co_name - vectors = [[random.random() for _ in range(16)] for _ in range(10)] - ids = [random.randint(1000000, 20000000) for _ in range(10)] - ServiceHandler._add_vectors = mock.MagicMock(return_value=(OK, ids)) - status, ret = self.client.add_vectors( - table_name=table_name, records=vectors) - assert status.OK() - assert ids == ret - - def test_create_index(self, started_app): - table_name = inspect.currentframe().f_code.co_name - unpacks = table_name, None - Parser.parse_proto_IndexParam = mock.MagicMock(return_value=(OK, - unpacks)) - ServiceHandler._create_index = mock.MagicMock(return_value=OK) - status = self.client.create_index(table_name=table_name) - assert status.OK() - - Parser.parse_proto_IndexParam = mock.MagicMock(return_value=(BAD, - None)) - status = self.client.create_index(table_name=table_name) - assert not status.OK() - - def test_drop_table(self, started_app): - table_name = inspect.currentframe().f_code.co_name - - Parser.parse_proto_TableName = mock.MagicMock( - return_value=(OK, table_name)) - ServiceHandler._delete_table = 
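All of these tests share one mechanism: class attributes on ServiceHandler and the parser are replaced wholesale with mock.MagicMock, so only the gRPC round trip is exercised, never the real handler logic. The mechanism in isolation (Handler is a stand-in class, not part of mishards):

import mock  # the standalone mock package pinned in requirements.txt

class Handler:
    def version(self):
        return 'real'

Handler.version = mock.MagicMock(return_value='stubbed')
assert Handler().version() == 'stubbed'   # the stub, not the real method
Handler.version.assert_called_once_with()

Note that the tests never restore the patched attributes, so a stub installed by one test stays visible to the next; each test therefore re-stubs whatever it depends on.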
mock.MagicMock(return_value=OK) - status = self.client.delete_table(table_name=table_name) - assert status.OK() - - Parser.parse_proto_TableName = mock.MagicMock( - return_value=(BAD, table_name)) - status = self.client.delete_table(table_name=table_name) - assert not status.OK() - - def test_has_table(self, started_app): - table_name = inspect.currentframe().f_code.co_name - - Parser.parse_proto_TableName = mock.MagicMock( - return_value=(OK, table_name)) - ServiceHandler._has_table = mock.MagicMock(return_value=(OK, True)) - has = self.client.has_table(table_name=table_name) - assert has - - Parser.parse_proto_TableName = mock.MagicMock( - return_value=(BAD, table_name)) - status, has = self.client.has_table(table_name=table_name) - assert not status.OK() - assert not has - - def test_create_table(self, started_app): - table_name = inspect.currentframe().f_code.co_name - dimension = 128 - table_schema = dict(table_name=table_name, - index_file_size=100, - metric_type=MetricType.L2, - dimension=dimension) - - ServiceHandler._create_table = mock.MagicMock(return_value=OK) - status = self.client.create_table(table_schema) - assert status.OK() - - Parser.parse_proto_TableSchema = mock.MagicMock(return_value=(BAD, - None)) - status = self.client.create_table(table_schema) - assert not status.OK() - - def random_data(self, n, dimension): - return [[random.random() for _ in range(dimension)] for _ in range(n)] - - def test_search(self, started_app): - table_name = inspect.currentframe().f_code.co_name - to_index_cnt = random.randint(10, 20) - table = TablesFactory(table_id=table_name, state=Tables.NORMAL) - to_index_files = TableFilesFactory.create_batch( - to_index_cnt, table=table, file_type=TableFiles.FILE_TYPE_TO_INDEX) - topk = random.randint(5, 10) - nq = random.randint(5, 10) - param = { - 'table_name': table_name, - 'query_records': self.random_data(nq, table.dimension), - 'top_k': topk, - 'nprobe': 2049 - } - - result = [ - milvus_pb2.TopKQueryResult(query_result_arrays=[ - milvus_pb2.QueryResult(id=i, distance=random.random()) - for i in range(topk) - ]) for i in range(nq) - ] - - mock_results = milvus_pb2.TopKQueryResultList(status=status_pb2.Status( - error_code=status_pb2.SUCCESS, reason="Success"), - topk_query_result=result) - - table_schema = TableSchema(table_name=table_name, - index_file_size=table.index_file_size, - metric_type=table.metric_type, - dimension=table.dimension) - - status, _ = self.client.search_vectors(**param) - assert status.code == Status.ILLEGAL_ARGUMENT - - param['nprobe'] = 2048 - RouterMixin.connection = mock.MagicMock(return_value=Milvus()) - RouterMixin.query_conn = mock.MagicMock(return_value=Milvus()) - Milvus.describe_table = mock.MagicMock(return_value=(BAD, - table_schema)) - status, ret = self.client.search_vectors(**param) - assert status.code == Status.TABLE_NOT_EXISTS - - Milvus.describe_table = mock.MagicMock(return_value=(OK, table_schema)) - Milvus.search_vectors_in_files = mock.MagicMock( - return_value=mock_results) - - status, ret = self.client.search_vectors(**param) - assert status.OK() - assert len(ret) == nq diff --git a/mishards/utilities.py b/mishards/utilities.py deleted file mode 100644 index 42e982b5f1..0000000000 --- a/mishards/utilities.py +++ /dev/null @@ -1,20 +0,0 @@ -import datetime -from mishards import exceptions - - -def format_date(start, end): - return ((start.year - 1900) * 10000 + (start.month - 1) * 100 + start.day, - (end.year - 1900) * 10000 + (end.month - 1) * 100 + end.day) - - -def range_to_date(range_obj, 
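format_date above packs a date into an integer: years since 1900 in the ten-thousands, zero-based month in the hundreds, day in the units. A worked check with arbitrarily chosen dates:

import datetime

def format_date(start, end):
    return ((start.year - 1900) * 10000 + (start.month - 1) * 100 + start.day,
            (end.year - 1900) * 10000 + (end.month - 1) * 100 + end.day)

start = datetime.datetime(2019, 10, 21)
end = datetime.datetime(2019, 10, 22)
assert format_date(start, end) == (1190921, 1190922)

The same encoding explains the small literals in test_models.py: date=110 decodes to 1900-02-10 and date=120 to 1900-02-20, which is why the range (111, 121) matches only the date-120 raw files.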
metadata=None): - try: - start = datetime.datetime.strptime(range_obj.start_date, '%Y-%m-%d') - end = datetime.datetime.strptime(range_obj.end_date, '%Y-%m-%d') - assert start < end - except (ValueError, AssertionError): - raise exceptions.InvalidRangeError('Invalid time range: {} {}'.format( - range_obj.start_date, range_obj.end_date), - metadata=metadata) - - return format_date(start, end) diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index ae224e92ed..0000000000 --- a/requirements.txt +++ /dev/null @@ -1,36 +0,0 @@ -environs==4.2.0 -factory-boy==2.12.0 -Faker==1.0.7 -fire==0.1.3 -google-auth==1.6.3 -grpcio==1.22.0 -grpcio-tools==1.22.0 -kubernetes==10.0.1 -MarkupSafe==1.1.1 -marshmallow==2.19.5 -pymysql==0.9.3 -protobuf==3.9.1 -py==1.8.0 -pyasn1==0.4.7 -pyasn1-modules==0.2.6 -pylint==2.3.1 -pymilvus-test==0.2.28 -#pymilvus==0.2.0 -pyparsing==2.4.0 -pytest==4.6.3 -pytest-level==0.1.1 -pytest-print==0.1.2 -pytest-repeat==0.8.0 -pytest-timeout==1.3.3 -python-dateutil==2.8.0 -python-dotenv==0.10.3 -pytz==2019.1 -requests==2.22.0 -requests-oauthlib==1.2.0 -rsa==4.0 -six==1.12.0 -SQLAlchemy==1.3.5 -urllib3==1.25.3 -jaeger-client>=3.4.0 -grpcio-opentracing>=1.0 -mock==2.0.0 diff --git a/sd/__init__.py b/sd/__init__.py deleted file mode 100644 index 7943887d0f..0000000000 --- a/sd/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -import logging -import inspect -# from utils import singleton - -logger = logging.getLogger(__name__) - - -class ProviderManager: - PROVIDERS = {} - - @classmethod - def register_service_provider(cls, target): - if inspect.isfunction(target): - cls.PROVIDERS[target.__name__] = target - elif inspect.isclass(target): - name = target.__dict__.get('NAME', None) - name = name if name else target.__class__.__name__ - cls.PROVIDERS[name] = target - else: - assert False, 'Cannot register_service_provider for: {}'.format(target) - return target - - @classmethod - def get_provider(cls, name): - return cls.PROVIDERS.get(name, None) - - -from sd import kubernetes_provider, static_provider diff --git a/sd/kubernetes_provider.py b/sd/kubernetes_provider.py deleted file mode 100644 index eb113db007..0000000000 --- a/sd/kubernetes_provider.py +++ /dev/null @@ -1,331 +0,0 @@ -import os -import sys -if __name__ == '__main__': - sys.path.append(os.path.dirname(os.path.dirname( - os.path.abspath(__file__)))) - -import re -import logging -import time -import copy -import threading -import queue -import enum -from kubernetes import client, config, watch - -from utils import singleton -from sd import ProviderManager - -logger = logging.getLogger(__name__) - -INCLUSTER_NAMESPACE_PATH = '/var/run/secrets/kubernetes.io/serviceaccount/namespace' - - -class EventType(enum.Enum): - PodHeartBeat = 1 - Watch = 2 - - -class K8SMixin: - def __init__(self, namespace, in_cluster=False, **kwargs): - self.namespace = namespace - self.in_cluster = in_cluster - self.kwargs = kwargs - self.v1 = kwargs.get('v1', None) - if not self.namespace: - self.namespace = open(INCLUSTER_NAMESPACE_PATH).read() - - if not self.v1: - config.load_incluster_config( - ) if self.in_cluster else config.load_kube_config() - self.v1 = client.CoreV1Api() - - -class K8SHeartbeatHandler(threading.Thread, K8SMixin): - def __init__(self, - message_queue, - namespace, - label_selector, - in_cluster=False, - **kwargs): - K8SMixin.__init__(self, - namespace=namespace, - in_cluster=in_cluster, - **kwargs) - threading.Thread.__init__(self) - self.queue = message_queue - self.terminate = False - self.label_selector = 
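K8SMixin above exists to share one bootstrap across the watcher and heartbeat threads: pick the in-cluster or local kubeconfig, then hand back a CoreV1Api client. As a standalone sketch (make_core_v1 is an illustrative name):

from kubernetes import client, config

def make_core_v1(in_cluster=False):
    # In-cluster: service-account credentials; otherwise the local kubeconfig
    if in_cluster:
        config.load_incluster_config()
    else:
        config.load_kube_config()
    return client.CoreV1Api()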
label_selector - self.poll_interval = kwargs.get('poll_interval', 5) - - def run(self): - while not self.terminate: - try: - pods = self.v1.list_namespaced_pod( - namespace=self.namespace, - label_selector=self.label_selector) - event_message = {'eType': EventType.PodHeartBeat, 'events': []} - for item in pods.items: - pod = self.v1.read_namespaced_pod(name=item.metadata.name, - namespace=self.namespace) - name = pod.metadata.name - ip = pod.status.pod_ip - phase = pod.status.phase - reason = pod.status.reason - message = pod.status.message - ready = True if phase == 'Running' else False - - pod_event = dict(pod=name, - ip=ip, - ready=ready, - reason=reason, - message=message) - - event_message['events'].append(pod_event) - - self.queue.put(event_message) - - except Exception as exc: - logger.error(exc) - - time.sleep(self.poll_interval) - - def stop(self): - self.terminate = True - - -class K8SEventListener(threading.Thread, K8SMixin): - def __init__(self, message_queue, namespace, in_cluster=False, **kwargs): - K8SMixin.__init__(self, - namespace=namespace, - in_cluster=in_cluster, - **kwargs) - threading.Thread.__init__(self) - self.queue = message_queue - self.terminate = False - self.at_start_up = True - self._stop_event = threading.Event() - - def stop(self): - self.terminate = True - self._stop_event.set() - - def run(self): - resource_version = '' - w = watch.Watch() - for event in w.stream(self.v1.list_namespaced_event, - namespace=self.namespace, - field_selector='involvedObject.kind=Pod'): - if self.terminate: - break - - resource_version = int(event['object'].metadata.resource_version) - - info = dict( - eType=EventType.Watch, - pod=event['object'].involved_object.name, - reason=event['object'].reason, - message=event['object'].message, - start_up=self.at_start_up, - ) - self.at_start_up = False - # logger.info('Received event: {}'.format(info)) - self.queue.put(info) - - -class EventHandler(threading.Thread): - def __init__(self, mgr, message_queue, namespace, pod_patt, **kwargs): - threading.Thread.__init__(self) - self.mgr = mgr - self.queue = message_queue - self.kwargs = kwargs - self.terminate = False - self.pod_patt = re.compile(pod_patt) - self.namespace = namespace - - def stop(self): - self.terminate = True - - def on_drop(self, event, **kwargs): - pass - - def on_pod_started(self, event, **kwargs): - try_cnt = 3 - pod = None - while try_cnt > 0: - try_cnt -= 1 - try: - pod = self.mgr.v1.read_namespaced_pod(name=event['pod'], - namespace=self.namespace) - if not pod.status.pod_ip: - time.sleep(0.5) - continue - break - except client.rest.ApiException as exc: - time.sleep(0.5) - - if try_cnt <= 0 and not pod: - if not event['start_up']: - logger.error('Pod {} is started but cannot read pod'.format( - event['pod'])) - return - elif try_cnt <= 0 and not pod.status.pod_ip: - logger.warning('NoPodIPFoundError') - return - - logger.info('Register POD {} with IP {}'.format( - pod.metadata.name, pod.status.pod_ip)) - self.mgr.add_pod(name=pod.metadata.name, ip=pod.status.pod_ip) - - def on_pod_killing(self, event, **kwargs): - logger.info('Unregister POD {}'.format(event['pod'])) - self.mgr.delete_pod(name=event['pod']) - - def on_pod_heartbeat(self, event, **kwargs): - names = self.mgr.conn_mgr.conn_names - - running_names = set() - for each_event in event['events']: - if each_event['ready']: - self.mgr.add_pod(name=each_event['pod'], ip=each_event['ip']) - running_names.add(each_event['pod']) - else: - self.mgr.delete_pod(name=each_event['pod']) - - to_delete = names - 
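The heartbeat handler ends in a plain set reconciliation: anything registered with the connection manager but absent from the pods just reported ready gets unregistered. Reduced to the operation itself, with invented pod names:

registered = {'pod-a', 'pod-b', 'pod-c'}  # names the connection manager knows
running = {'pod-a', 'pod-c'}              # pods the latest poll saw as ready
for name in registered - running:
    print('unregister', name)             # only pod-b is dropped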
running_names - for name in to_delete: - self.mgr.delete_pod(name) - - logger.info(self.mgr.conn_mgr.conn_names) - - def handle_event(self, event): - if event['eType'] == EventType.PodHeartBeat: - return self.on_pod_heartbeat(event) - - if not event or (event['reason'] not in ('Started', 'Killing')): - return self.on_drop(event) - - if not re.match(self.pod_patt, event['pod']): - return self.on_drop(event) - - logger.info('Handling event: {}'.format(event)) - - if event['reason'] == 'Started': - return self.on_pod_started(event) - - return self.on_pod_killing(event) - - def run(self): - while not self.terminate: - try: - event = self.queue.get(timeout=1) - self.handle_event(event) - except queue.Empty: - continue - - -class KubernetesProviderSettings: - def __init__(self, namespace, pod_patt, label_selector, in_cluster, - poll_interval, port=None, **kwargs): - self.namespace = namespace - self.pod_patt = pod_patt - self.label_selector = label_selector - self.in_cluster = in_cluster - self.poll_interval = poll_interval - self.port = int(port) if port else 19530 - - -@singleton -@ProviderManager.register_service_provider -class KubernetesProvider(object): - NAME = 'Kubernetes' - - def __init__(self, settings, conn_mgr, **kwargs): - self.namespace = settings.namespace - self.pod_patt = settings.pod_patt - self.label_selector = settings.label_selector - self.in_cluster = settings.in_cluster - self.poll_interval = settings.poll_interval - self.port = settings.port - self.kwargs = kwargs - self.queue = queue.Queue() - - self.conn_mgr = conn_mgr - - if not self.namespace: - self.namespace = open(incluster_namespace_path).read() - - config.load_incluster_config( - ) if self.in_cluster else config.load_kube_config() - self.v1 = client.CoreV1Api() - - self.listener = K8SEventListener(message_queue=self.queue, - namespace=self.namespace, - in_cluster=self.in_cluster, - v1=self.v1, - **kwargs) - - self.pod_heartbeater = K8SHeartbeatHandler( - message_queue=self.queue, - namespace=self.namespace, - label_selector=self.label_selector, - in_cluster=self.in_cluster, - v1=self.v1, - poll_interval=self.poll_interval, - **kwargs) - - self.event_handler = EventHandler(mgr=self, - message_queue=self.queue, - namespace=self.namespace, - pod_patt=self.pod_patt, - **kwargs) - - def add_pod(self, name, ip): - self.conn_mgr.register(name, 'tcp://{}:{}'.format(ip, self.port)) - - def delete_pod(self, name): - self.conn_mgr.unregister(name) - - def start(self): - self.listener.daemon = True - self.listener.start() - self.event_handler.start() - - self.pod_heartbeater.start() - - def stop(self): - self.listener.stop() - self.pod_heartbeater.stop() - self.event_handler.stop() - - -if __name__ == '__main__': - logging.basicConfig(level=logging.INFO) - - class Connect: - def register(self, name, value): - logger.error('Register: {} - {}'.format(name, value)) - - def unregister(self, name): - logger.error('Unregister: {}'.format(name)) - - @property - def conn_names(self): - return set() - - connect_mgr = Connect() - - settings = KubernetesProviderSettings(namespace='xp', - pod_patt=".*-ro-servers-.*", - label_selector='tier=ro-servers', - poll_interval=5, - in_cluster=False) - - provider_class = ProviderManager.get_provider('Kubernetes') - t = provider_class(conn_mgr=connect_mgr, settings=settings) - t.start() - cnt = 100 - while cnt > 0: - time.sleep(2) - cnt -= 1 - t.stop() diff --git a/sd/static_provider.py b/sd/static_provider.py deleted file mode 100644 index e88780740f..0000000000 --- a/sd/static_provider.py +++ 
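Both service-discovery providers register themselves through the decorator in sd/__init__.py. One caution visible in that source: when a class omits NAME, the fallback reads target.__class__.__name__, which for any class is just 'type', so providers should always set NAME explicitly. The contract in isolation (DummyProvider is illustrative; importing sd pulls in both concrete providers, so the kubernetes client must be installed):

from sd import ProviderManager

@ProviderManager.register_service_provider
class DummyProvider:
    NAME = 'Dummy'  # the registry key; set it explicitly (see caution above)

assert ProviderManager.get_provider('Dummy') is DummyProvider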
/dev/null @@ -1,39 +0,0 @@ -import os -import sys -if __name__ == '__main__': - sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -import socket -from utils import singleton -from sd import ProviderManager - - -class StaticProviderSettings: - def __init__(self, hosts, port=None): - self.hosts = hosts - self.port = int(port) if port else 19530 - - -@singleton -@ProviderManager.register_service_provider -class KubernetesProvider(object): - NAME = 'Static' - - def __init__(self, settings, conn_mgr, **kwargs): - self.conn_mgr = conn_mgr - self.hosts = [socket.gethostbyname(host) for host in settings.hosts] - self.port = settings.port - - def start(self): - for host in self.hosts: - self.add_pod(host, host) - - def stop(self): - for host in self.hosts: - self.delete_pod(host) - - def add_pod(self, name, ip): - self.conn_mgr.register(name, 'tcp://{}:{}'.format(ip, self.port)) - - def delete_pod(self, name): - self.conn_mgr.unregister(name) diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 4a88432914..0000000000 --- a/setup.cfg +++ /dev/null @@ -1,4 +0,0 @@ -[tool:pytest] -testpaths = mishards -log_cli=true -log_cli_level=info diff --git a/start_services.yml b/start_services.yml deleted file mode 100644 index 57fe061bb7..0000000000 --- a/start_services.yml +++ /dev/null @@ -1,45 +0,0 @@ -version: "2.3" -services: - milvus: - runtime: nvidia - restart: always - image: registry.zilliz.com/milvus/engine:branch-0.5.0-release-4316de - # ports: - # - "0.0.0.0:19530:19530" - volumes: - - /tmp/milvus/db:/opt/milvus/db - - jaeger: - restart: always - image: jaegertracing/all-in-one:1.14 - ports: - - "0.0.0.0:5775:5775/udp" - - "0.0.0.0:16686:16686" - - "0.0.0.0:9441:9441" - environment: - COLLECTOR_ZIPKIN_HTTP_PORT: 9411 - - mishards: - restart: always - image: registry.zilliz.com/milvus/mishards:v0.0.4 - ports: - - "0.0.0.0:19530:19531" - - "0.0.0.0:19532:19532" - volumes: - - /tmp/milvus/db:/tmp/milvus/db - # - /tmp/mishards_env:/source/mishards/.env - command: ["python", "mishards/main.py"] - environment: - FROM_EXAMPLE: 'true' - DEBUG: 'true' - SERVER_PORT: 19531 - WOSERVER: tcp://milvus:19530 - SD_STATIC_HOSTS: milvus - TRACING_TYPE: jaeger - TRACING_SERVICE_NAME: mishards-demo - TRACING_REPORTING_HOST: jaeger - TRACING_REPORTING_PORT: 5775 - - depends_on: - - milvus - - jaeger diff --git a/tracing/__init__.py b/tracing/__init__.py deleted file mode 100644 index 64a5b50d15..0000000000 --- a/tracing/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -from contextlib import contextmanager - - -def empty_server_interceptor_decorator(target_server, interceptor): - return target_server - - -@contextmanager -def EmptySpan(*args, **kwargs): - yield None - return - - -class Tracer: - def __init__(self, - tracer=None, - interceptor=None, - server_decorator=empty_server_interceptor_decorator): - self.tracer = tracer - self.interceptor = interceptor - self.server_decorator = server_decorator - - def decorate(self, server): - return self.server_decorator(server, self.interceptor) - - @property - def empty(self): - return self.tracer is None - - def close(self): - self.tracer and self.tracer.close() - - def start_span(self, - operation_name=None, - child_of=None, - references=None, - tags=None, - start_time=None, - ignore_active_span=False): - if self.empty: - return EmptySpan() - return self.tracer.start_span(operation_name, child_of, references, - tags, start_time, ignore_active_span) diff --git a/tracing/factory.py b/tracing/factory.py deleted file mode 100644 index 
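Tracer above is written so call sites never branch on whether tracing is configured: with no backend, start_span hands out EmptySpan, a context manager that yields None. Usage as the class is written:

from tracing import Tracer

tracer = Tracer()  # no backend, no interceptor
assert tracer.empty
with tracer.start_span('request') as span:
    assert span is None  # callers can guard on span before tagging anything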
14fcde2eb3..0000000000 --- a/tracing/factory.py +++ /dev/null @@ -1,40 +0,0 @@ -import logging -from jaeger_client import Config -from grpc_opentracing.grpcext import intercept_server -from grpc_opentracing import open_tracing_server_interceptor - -from tracing import (Tracer, empty_server_interceptor_decorator) - -logger = logging.getLogger(__name__) - - -class TracerFactory: - @classmethod - def new_tracer(cls, - tracer_type, - tracer_config, - span_decorator=None, - **kwargs): - if not tracer_type: - return Tracer() - config = tracer_config.TRACING_CONFIG - service_name = tracer_config.TRACING_SERVICE_NAME - validate = tracer_config.TRACING_VALIDATE - # if not tracer_type: - # tracer_type = 'jaeger' - # config = tracer_config.DEFAULT_TRACING_CONFIG - - if tracer_type.lower() == 'jaeger': - config = Config(config=config, - service_name=service_name, - validate=validate) - - tracer = config.initialize_tracer() - tracer_interceptor = open_tracing_server_interceptor( - tracer, - log_payloads=tracer_config.TRACING_LOG_PAYLOAD, - span_decorator=span_decorator) - - return Tracer(tracer, tracer_interceptor, intercept_server) - - assert False, 'Unsupported tracer type: {}'.format(tracer_type) diff --git a/utils/__init__.py b/utils/__init__.py deleted file mode 100644 index c1d55e76c0..0000000000 --- a/utils/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from functools import wraps - - -def singleton(cls): - instances = {} - @wraps(cls) - def getinstance(*args, **kw): - if cls not in instances: - instances[cls] = cls(*args, **kw) - return instances[cls] - return getinstance diff --git a/utils/logger_helper.py b/utils/logger_helper.py deleted file mode 100644 index b4e3b9c5b6..0000000000 --- a/utils/logger_helper.py +++ /dev/null @@ -1,152 +0,0 @@ -import os -import datetime -from pytz import timezone -from logging import Filter -import logging.config - - -class InfoFilter(logging.Filter): - def filter(self, rec): - return rec.levelno == logging.INFO - - -class DebugFilter(logging.Filter): - def filter(self, rec): - return rec.levelno == logging.DEBUG - - -class WarnFilter(logging.Filter): - def filter(self, rec): - return rec.levelno == logging.WARN - - -class ErrorFilter(logging.Filter): - def filter(self, rec): - return rec.levelno == logging.ERROR - - -class CriticalFilter(logging.Filter): - def filter(self, rec): - return rec.levelno == logging.CRITICAL - - -COLORS = { - 'HEADER': '\033[95m', - 'INFO': '\033[92m', - 'DEBUG': '\033[94m', - 'WARNING': '\033[93m', - 'ERROR': '\033[95m', - 'CRITICAL': '\033[91m', - 'ENDC': '\033[0m', -} - - -class ColorFulFormatColMixin: - def format_col(self, message_str, level_name): - if level_name in COLORS.keys(): - message_str = COLORS.get(level_name) + message_str + COLORS.get( - 'ENDC') - return message_str - - -class ColorfulFormatter(logging.Formatter, ColorFulFormatColMixin): - def format(self, record): - message_str = super(ColorfulFormatter, self).format(record) - - return self.format_col(message_str, level_name=record.levelname) - - -def config(log_level, log_path, name, tz='UTC'): - def build_log_file(level, log_path, name, tz): - utc_now = datetime.datetime.utcnow() - utc_tz = timezone('UTC') - local_tz = timezone(tz) - tznow = utc_now.replace(tzinfo=utc_tz).astimezone(local_tz) - return '{}-{}-{}.log'.format(os.path.join(log_path, name), tznow.strftime("%m-%d-%Y-%H:%M:%S"), - level) - - if not os.path.exists(log_path): - os.makedirs(log_path) - - LOGGING = { - 'version': 1, - 'disable_existing_loggers': False, - 'formatters': { - 'default': { - 
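The singleton decorator from utils/__init__.py caches one instance per class. A subtlety worth spelling out: constructor arguments on any call after the first are ignored, since the cached instance is returned as-is. In isolation (Config is an illustrative class):

from utils import singleton

@singleton
class Config:
    def __init__(self, value):
        self.value = value

a = Config(1)
b = Config(2)  # arguments after the first call are silently ignored
assert a is b and a.value == 1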
'format': '%(asctime)s | %(levelname)s | %(name)s | %(threadName)s: %(message)s (%(filename)s:%(lineno)s)', - }, - 'colorful_console': { - 'format': '%(asctime)s | %(levelname)s | %(name)s | %(threadName)s: %(message)s (%(filename)s:%(lineno)s)', - '()': ColorfulFormatter, - }, - }, - 'filters': { - 'InfoFilter': { - '()': InfoFilter, - }, - 'DebugFilter': { - '()': DebugFilter, - }, - 'WarnFilter': { - '()': WarnFilter, - }, - 'ErrorFilter': { - '()': ErrorFilter, - }, - 'CriticalFilter': { - '()': CriticalFilter, - }, - }, - 'handlers': { - 'milvus_celery_console': { - 'class': 'logging.StreamHandler', - 'formatter': 'colorful_console', - }, - 'milvus_debug_file': { - 'level': 'DEBUG', - 'filters': ['DebugFilter'], - 'class': 'logging.handlers.RotatingFileHandler', - 'formatter': 'default', - 'filename': build_log_file('debug', log_path, name, tz) - }, - 'milvus_info_file': { - 'level': 'INFO', - 'filters': ['InfoFilter'], - 'class': 'logging.handlers.RotatingFileHandler', - 'formatter': 'default', - 'filename': build_log_file('info', log_path, name, tz) - }, - 'milvus_warn_file': { - 'level': 'WARN', - 'filters': ['WarnFilter'], - 'class': 'logging.handlers.RotatingFileHandler', - 'formatter': 'default', - 'filename': build_log_file('warn', log_path, name, tz) - }, - 'milvus_error_file': { - 'level': 'ERROR', - 'filters': ['ErrorFilter'], - 'class': 'logging.handlers.RotatingFileHandler', - 'formatter': 'default', - 'filename': build_log_file('error', log_path, name, tz) - }, - 'milvus_critical_file': { - 'level': 'CRITICAL', - 'filters': ['CriticalFilter'], - 'class': 'logging.handlers.RotatingFileHandler', - 'formatter': 'default', - 'filename': build_log_file('critical', log_path, name, tz) - }, - }, - 'loggers': { - '': { - 'handlers': ['milvus_celery_console', 'milvus_info_file', 'milvus_debug_file', 'milvus_warn_file', - 'milvus_error_file', 'milvus_critical_file'], - 'level': log_level, - 'propagate': False - }, - }, - 'propagate': False, - } - - logging.config.dictConfig(LOGGING)
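Tying it together: settings.py calls config() once at import time with the level, path, name, and timezone read from the environment. A standalone sketch of that call, using the same default path and name as the settings above; each per-level file receives exactly its own level thanks to the filters, while the colorized console handler sees everything:

import logging
from utils.logger_helper import config

config('DEBUG', '/tmp/mishards', 'logfile', 'UTC')
logger = logging.getLogger(__name__)
logger.info('console plus the info file; per-level filters keep the files disjoint')
logger.error('console plus the error file')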