diff --git a/mishards/__init__.py b/mishards/__init__.py
index 4bd77d8c60..759e8c2e5a 100644
--- a/mishards/__init__.py
+++ b/mishards/__init__.py
@@ -27,7 +27,10 @@ def create_app(testing_config=None):
     tracer = TracerFactory.new_tracer(config.TRACING_TYPE,
                                       settings.TracingConfig,
                                       span_decorator=GrpcSpanDecorator())
-    grpc_server.init_app(conn_mgr=connect_mgr, tracer=tracer, discover=discover)
+    from mishards.routings import RouterFactory
+    router = RouterFactory.new_router(config.ROUTER_CLASS_NAME, connect_mgr)
+
+    grpc_server.init_app(conn_mgr=connect_mgr, tracer=tracer, router=router, discover=discover)
 
     from mishards import exception_handlers
 
diff --git a/mishards/routings.py b/mishards/routings.py
new file mode 100644
index 0000000000..a61352f40b
--- /dev/null
+++ b/mishards/routings.py
@@ -0,0 +1,81 @@
+import logging
+from sqlalchemy import exc as sqlalchemy_exc
+from sqlalchemy import and_
+
+from mishards import exceptions, db
+from mishards.hash_ring import HashRing
+from mishards.models import Tables
+
+logger = logging.getLogger(__name__)
+
+
+class RouteManager:
+    ROUTER_CLASSES = {}
+
+    @classmethod
+    def register_router_class(cls, target):
+        name = target.__dict__.get('NAME', None)
+        name = name if name else target.__name__
+        cls.ROUTER_CLASSES[name] = target
+        return target
+
+    @classmethod
+    def get_router_class(cls, name):
+        return cls.ROUTER_CLASSES.get(name, None)
+
+
+class RouterFactory:
+    @classmethod
+    def new_router(cls, name, conn_mgr, **kwargs):
+        router_class = RouteManager.get_router_class(name)
+        assert router_class
+        return router_class(conn_mgr, **kwargs)
+
+
+class RouterMixin:
+    def __init__(self, conn_mgr):
+        self.conn_mgr = conn_mgr
+
+    def routing(self, table_name, metadata=None, **kwargs):
+        raise NotImplementedError()
+
+
+@RouteManager.register_router_class
+class FileBasedHashRingRouter(RouterMixin):
+    NAME = 'FileBasedHashRingRouter'
+
+    def __init__(self, conn_mgr, **kwargs):
+        super(FileBasedHashRingRouter, self).__init__(conn_mgr)
+
+    def routing(self, table_name, metadata=None, **kwargs):
+        range_array = kwargs.pop('range_array', None)
+        return self._route(table_name, range_array, metadata, **kwargs)
+
+    def _route(self, table_name, range_array, metadata=None, **kwargs):
+        # PXU TODO: Implement Thread-local Context
+        try:
+            table = db.Session.query(Tables).filter(
+                and_(Tables.table_id == table_name,
+                     Tables.state != Tables.TO_DELETE)).first()
+        except sqlalchemy_exc.SQLAlchemyError as e:
+            raise exceptions.DBError(message=str(e), metadata=metadata)
+
+        if not table:
+            raise exceptions.TableNotFoundError(table_name, metadata=metadata)
+        files = table.files_to_search(range_array)
+
+        servers = self.conn_mgr.conn_names
+        logger.info('Available servers: {}'.format(servers))
+
+        ring = HashRing(servers)
+
+        routing = {}
+
+        for f in files:
+            target_host = ring.get_node(str(f.id))
+            sub = routing.get(target_host, None)
+            if not sub:
+                routing[target_host] = {'table_id': table_name, 'file_ids': []}
+            routing[target_host]['file_ids'].append(str(f.id))
+
+        return routing
diff --git a/mishards/server.py b/mishards/server.py
index dcaacd0fbc..20be8f1746 100644
--- a/mishards/server.py
+++ b/mishards/server.py
@@ -22,17 +22,24 @@ class Server:
         self.error_handlers = {}
         self.exit_flag = False
 
-    def init_app(self, conn_mgr, tracer, discover, port=19530, max_workers=10, **kwargs):
+    def init_app(self,
+                 conn_mgr,
+                 tracer,
+                 router,
+                 discover,
+                 port=19530,
+                 max_workers=10,
+                 **kwargs):
         self.port = int(port)
         self.conn_mgr = conn_mgr
         self.tracer = tracer
+        self.router = router
         self.discover = discover
 
         self.server_impl = grpc.server(
             thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers),
             options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
-                     (cygrpc.ChannelArgKey.max_receive_message_length, -1)]
-        )
+                     (cygrpc.ChannelArgKey.max_receive_message_length, -1)])
 
         self.server_impl = self.tracer.decorate(self.server_impl)
 
@@ -43,8 +50,8 @@ class Server:
         url = urlparse(woserver)
         ip = socket.gethostbyname(url.hostname)
         socket.inet_pton(socket.AF_INET, ip)
-        self.conn_mgr.register('WOSERVER',
-                               '{}://{}:{}'.format(url.scheme, ip, url.port or 80))
+        self.conn_mgr.register(
+            'WOSERVER', '{}://{}:{}'.format(url.scheme, ip, url.port or 80))
 
     def register_pre_run_handler(self, func):
         logger.info('Regiterring {} into server pre_run_handlers'.format(func))
@@ -65,9 +72,11 @@ class Server:
 
     def errorhandler(self, exception):
         if inspect.isclass(exception) and issubclass(exception, Exception):
+
             def wrapper(func):
                 self.error_handlers[exception] = func
                 return func
+
             return wrapper
         return exception
 
@@ -78,8 +87,12 @@ class Server:
 
     def start(self, port=None):
         handler_class = self.decorate_handler(ServiceHandler)
-        add_MilvusServiceServicer_to_server(handler_class(conn_mgr=self.conn_mgr, tracer=self.tracer), self.server_impl)
-        self.server_impl.add_insecure_port("[::]:{}".format(str(port or self._port)))
+        add_MilvusServiceServicer_to_server(
+            handler_class(conn_mgr=self.conn_mgr,
+                          tracer=self.tracer,
+                          router=self.router), self.server_impl)
+        self.server_impl.add_insecure_port("[::]:{}".format(
+            str(port or self._port)))
         self.server_impl.start()
 
     def run(self, port):
diff --git a/mishards/service_handler.py b/mishards/service_handler.py
index 1396466568..e26f2bfd74 100644
--- a/mishards/service_handler.py
+++ b/mishards/service_handler.py
@@ -3,9 +3,6 @@ import time
 import datetime
 from collections import defaultdict
 
-from sqlalchemy import and_
-from sqlalchemy import exc as sqlalchemy_exc
-
 from concurrent.futures import ThreadPoolExecutor
 from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
 from milvus.grpc_gen.milvus_pb2 import TopKQueryResult
@@ -15,8 +12,7 @@ from milvus.client import types as Types
 from mishards import (db, settings, exceptions)
 from mishards.grpc_utils import mark_grpc_method
 from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser
-from mishards.models import Tables, TableFiles
-from mishards.hash_ring import HashRing
+from mishards import utilities
 
 logger = logging.getLogger(__name__)
 
@@ -24,11 +20,12 @@ logger = logging.getLogger(__name__)
 class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
     MAX_NPROBE = 2048
 
-    def __init__(self, conn_mgr, tracer, *args, **kwargs):
+    def __init__(self, conn_mgr, tracer, router, *args, **kwargs):
         self.conn_mgr = conn_mgr
         self.table_meta = {}
         self.error_handlers = {}
         self.tracer = tracer
+        self.router = router
 
     def connection(self, metadata=None):
         conn = self.conn_mgr.conn('WOSERVER', metadata=metadata)
@@ -43,56 +40,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         conn.on_connect(metadata=metadata)
         return conn.conn
 
-    def _format_date(self, start, end):
-        return ((start.year - 1900) * 10000 + (start.month - 1) * 100 + start.day, (end.year - 1900) * 10000 + (end.month - 1) * 100 + end.day)
-
-    def _range_to_date(self, range_obj, metadata=None):
-        try:
-            start = datetime.datetime.strptime(range_obj.start_date, '%Y-%m-%d')
-            end = datetime.datetime.strptime(range_obj.end_date, '%Y-%m-%d')
-            assert start < end
-        except (ValueError, AssertionError):
-            raise exceptions.InvalidRangeError('Invalid time range: {} {}'.format(
-                range_obj.start_date, range_obj.end_date
-            ), metadata=metadata)
-
-        return self._format_date(start, end)
-
-    def _get_routing_file_ids(self, table_id, range_array, metadata=None):
-        # PXU TODO: Implement Thread-local Context
-        try:
-            table = db.Session.query(Tables).filter(and_(
-                Tables.table_id == table_id,
-                Tables.state != Tables.TO_DELETE
-            )).first()
-        except sqlalchemy_exc.SQLAlchemyError as e:
-            raise exceptions.DBError(message=str(e), metadata=metadata)
-
-        if not table:
-            raise exceptions.TableNotFoundError(table_id, metadata=metadata)
-        files = table.files_to_search(range_array)
-
-        servers = self.conn_mgr.conn_names
-        logger.info('Available servers: {}'.format(servers))
-
-        ring = HashRing(servers)
-
-        routing = {}
-
-        for f in files:
-            target_host = ring.get_node(str(f.id))
-            sub = routing.get(target_host, None)
-            if not sub:
-                routing[target_host] = {
-                    'table_id': table_id,
-                    'file_ids': []
-                }
-            routing[target_host]['file_ids'].append(str(f.id))
-
-        return routing
-
     def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs):
-        status = status_pb2.Status(error_code=status_pb2.SUCCESS, reason="Success")
+        status = status_pb2.Status(error_code=status_pb2.SUCCESS,
+                                   reason="Success")
         if not files_n_topk_results:
             return status, []
 
@@ -103,10 +53,14 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
             if isinstance(files_collection, tuple):
                 status, _ = files_collection
                 return status, []
-            for request_pos, each_request_results in enumerate(files_collection.topk_query_result):
-                request_results[request_pos].extend(each_request_results.query_result_arrays)
-                request_results[request_pos] = sorted(request_results[request_pos], key=lambda x: x.distance,
-                                                      reverse=reverse)[:topk]
+            for request_pos, each_request_results in enumerate(
+                    files_collection.topk_query_result):
+                request_results[request_pos].extend(
+                    each_request_results.query_result_arrays)
+                request_results[request_pos] = sorted(
+                    request_results[request_pos],
+                    key=lambda x: x.distance,
+                    reverse=reverse)[:topk]
 
         calc_time = time.time() - calc_time
         logger.info('Merge takes {}'.format(calc_time))
@@ -120,15 +74,27 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
 
         return status, topk_query_result
 
-    def _do_query(self, context, table_id, table_meta, vectors, topk, nprobe, range_array=None, **kwargs):
+    def _do_query(self,
+                  context,
+                  table_id,
+                  table_meta,
+                  vectors,
+                  topk,
+                  nprobe,
+                  range_array=None,
+                  **kwargs):
         metadata = kwargs.get('metadata', None)
-        range_array = [self._range_to_date(r, metadata=metadata) for r in range_array] if range_array else None
+        range_array = [
+            utilities.range_to_date(r, metadata=metadata) for r in range_array
+        ] if range_array else None
 
         routing = {}
-        p_span = None if self.tracer.empty else context.get_active_span().context
-        with self.tracer.start_span('get_routing',
-                                    child_of=p_span):
-            routing = self._get_routing_file_ids(table_id, range_array, metadata=metadata)
+        p_span = None if self.tracer.empty else context.get_active_span(
+        ).context
+        with self.tracer.start_span('get_routing', child_of=p_span):
+            routing = self.router.routing(table_id,
+                                          range_array=range_array,
+                                          metadata=metadata)
 
         logger.info('Routing: {}'.format(routing))
         metadata = kwargs.get('metadata', None)
@@ -139,42 +105,51 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         workers = settings.SEARCH_WORKER_SIZE
 
         def search(addr, query_params, vectors, topk, nprobe, **kwargs):
-            logger.info('Send Search Request: addr={};params={};nq={};topk={};nprobe={}'.format(
-                addr, query_params, len(vectors), topk, nprobe
-            ))
+            logger.info(
+                'Send Search Request: addr={};params={};nq={};topk={};nprobe={}'
+                .format(addr, query_params, len(vectors), topk, nprobe))
             conn = self.query_conn(addr, metadata=metadata)
             start = time.time()
             span = kwargs.get('span', None)
-            span = span if span else (None if self.tracer.empty else context.get_active_span().context)
+            span = span if span else (None if self.tracer.empty else
+                                      context.get_active_span().context)
 
             with self.tracer.start_span('search_{}'.format(addr),
                                         child_of=span):
-                ret = conn.search_vectors_in_files(table_name=query_params['table_id'],
-                                                   file_ids=query_params['file_ids'],
-                                                   query_records=vectors,
-                                                   top_k=topk,
-                                                   nprobe=nprobe,
-                                                   lazy=True)
+                ret = conn.search_vectors_in_files(
+                    table_name=query_params['table_id'],
+                    file_ids=query_params['file_ids'],
+                    query_records=vectors,
+                    top_k=topk,
+                    nprobe=nprobe,
+                    lazy=True)
                 end = time.time()
                 logger.info('search_vectors_in_files takes: {}'.format(end - start))
 
             all_topk_results.append(ret)
 
-        with self.tracer.start_span('do_search',
-                                    child_of=p_span) as span:
+        with self.tracer.start_span('do_search', child_of=p_span) as span:
             with ThreadPoolExecutor(max_workers=workers) as pool:
                 for addr, params in routing.items():
-                    res = pool.submit(search, addr, params, vectors, topk, nprobe, span=span)
+                    res = pool.submit(search,
+                                      addr,
+                                      params,
+                                      vectors,
+                                      topk,
+                                      nprobe,
+                                      span=span)
                     rs.append(res)
 
                 for res in rs:
                     res.result()
 
         reverse = table_meta.metric_type == Types.MetricType.IP
-        with self.tracer.start_span('do_merge',
-                                    child_of=p_span):
-            return self._do_merge(all_topk_results, topk, reverse=reverse, metadata=metadata)
+        with self.tracer.start_span('do_merge', child_of=p_span):
+            return self._do_merge(all_topk_results,
+                                  topk,
+                                  reverse=reverse,
+                                  metadata=metadata)
 
     def _create_table(self, table_schema):
         return self.connection().create_table(table_schema)
@@ -184,13 +159,15 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         _status, _table_schema = Parser.parse_proto_TableSchema(request)
 
         if not _status.OK():
-            return status_pb2.Status(error_code=_status.code, reason=_status.message)
+            return status_pb2.Status(error_code=_status.code,
+                                     reason=_status.message)
 
         logger.info('CreateTable {}'.format(_table_schema['table_name']))
 
         _status = self._create_table(_table_schema)
 
-        return status_pb2.Status(error_code=_status.code, reason=_status.message)
+        return status_pb2.Status(error_code=_status.code,
+                                 reason=_status.message)
 
     def _has_table(self, table_name, metadata=None):
         return self.connection(metadata=metadata).has_table(table_name)
@@ -200,20 +177,18 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         _status, _table_name = Parser.parse_proto_TableName(request)
 
         if not _status.OK():
-            return milvus_pb2.BoolReply(
-                status=status_pb2.Status(error_code=_status.code, reason=_status.message),
-                bool_reply=False
-            )
+            return milvus_pb2.BoolReply(status=status_pb2.Status(
+                error_code=_status.code, reason=_status.message),
+                                        bool_reply=False)
 
         logger.info('HasTable {}'.format(_table_name))
-        _bool = self._has_table(_table_name, metadata={
-            'resp_class': milvus_pb2.BoolReply})
+        _bool = self._has_table(_table_name,
+                                metadata={'resp_class': milvus_pb2.BoolReply})
 
-        return milvus_pb2.BoolReply(
-            status=status_pb2.Status(error_code=status_pb2.SUCCESS, reason="OK"),
-            bool_reply=_bool
-        )
+        return milvus_pb2.BoolReply(status=status_pb2.Status(
+            error_code=status_pb2.SUCCESS, reason="OK"),
+                                    bool_reply=_bool)
 
     def _delete_table(self, table_name):
         return self.connection().delete_table(table_name)
@@ -223,13 +198,15 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         _status, _table_name = Parser.parse_proto_TableName(request)
 
         if not _status.OK():
-            return status_pb2.Status(error_code=_status.code, reason=_status.message)
+            return status_pb2.Status(error_code=_status.code,
+                                     reason=_status.message)
 
         logger.info('DropTable {}'.format(_table_name))
 
         _status = self._delete_table(_table_name)
 
-        return status_pb2.Status(error_code=_status.code, reason=_status.message)
+        return status_pb2.Status(error_code=_status.code,
+                                 reason=_status.message)
 
     def _create_index(self, table_name, index):
         return self.connection().create_index(table_name, index)
@@ -239,7 +216,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         _status, unpacks = Parser.parse_proto_IndexParam(request)
 
         if not _status.OK():
-            return status_pb2.Status(error_code=_status.code, reason=_status.message)
+            return status_pb2.Status(error_code=_status.code,
+                                     reason=_status.message)
 
         _table_name, _index = unpacks
 
@@ -248,21 +226,22 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         # TODO: interface create_table incompleted
         _status = self._create_index(_table_name, _index)
 
-        return status_pb2.Status(error_code=_status.code, reason=_status.message)
+        return status_pb2.Status(error_code=_status.code,
+                                 reason=_status.message)
 
     def _add_vectors(self, param, metadata=None):
-        return self.connection(metadata=metadata).add_vectors(None, None, insert_param=param)
+        return self.connection(metadata=metadata).add_vectors(
+            None, None, insert_param=param)
 
     @mark_grpc_method
     def Insert(self, request, context):
         logger.info('Insert')
         # TODO: Ths SDK interface add_vectors() could update, add a key 'row_id_array'
-        _status, _ids = self._add_vectors(metadata={
-            'resp_class': milvus_pb2.VectorIds}, param=request)
-        return milvus_pb2.VectorIds(
-            status=status_pb2.Status(error_code=_status.code, reason=_status.message),
-            vector_id_array=_ids
-        )
+        _status, _ids = self._add_vectors(
+            metadata={'resp_class': milvus_pb2.VectorIds}, param=request)
+        return milvus_pb2.VectorIds(status=status_pb2.Status(
+            error_code=_status.code, reason=_status.message),
+                                    vector_id_array=_ids)
 
     @mark_grpc_method
     def Search(self, request, context):
@@ -272,22 +251,23 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         table_name = request.table_name
         topk = request.topk
         nprobe = request.nprobe
-        logger.info('Search {}: topk={} nprobe={}'.format(table_name, topk, nprobe))
+        logger.info('Search {}: topk={} nprobe={}'.format(
+            table_name, topk, nprobe))
 
-        metadata = {
-            'resp_class': milvus_pb2.TopKQueryResultList
-        }
+        metadata = {'resp_class': milvus_pb2.TopKQueryResultList}
 
         if nprobe > self.MAX_NPROBE or nprobe <= 0:
-            raise exceptions.InvalidArgumentError(message='Invalid nprobe: {}'.format(nprobe),
-                                                  metadata=metadata)
+            raise exceptions.InvalidArgumentError(
+                message='Invalid nprobe: {}'.format(nprobe), metadata=metadata)
 
         table_meta = self.table_meta.get(table_name, None)
 
         if not table_meta:
-            status, info = self.connection(metadata=metadata).describe_table(table_name)
+            status, info = self.connection(
+                metadata=metadata).describe_table(table_name)
             if not status.OK():
-                raise exceptions.TableNotFoundError(table_name, metadata=metadata)
+                raise exceptions.TableNotFoundError(table_name,
+                                                    metadata=metadata)
 
             self.table_meta[table_name] = info
             table_meta = info
@@ -304,16 +284,22 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
             query_range_array.append(
                 Range(query_range.start_value, query_range.end_value))
 
-        status, results = self._do_query(context, table_name, table_meta, query_record_array, topk,
-                                         nprobe, query_range_array, metadata=metadata)
+        status, results = self._do_query(context,
+                                         table_name,
+                                         table_meta,
+                                         query_record_array,
+                                         topk,
+                                         nprobe,
+                                         query_range_array,
+                                         metadata=metadata)
 
         now = time.time()
         logger.info('SearchVector takes: {}'.format(now - start))
 
         topk_result_list = milvus_pb2.TopKQueryResultList(
-            status=status_pb2.Status(error_code=status.error_code, reason=status.reason),
-            topk_query_result=results
-        )
+            status=status_pb2.Status(error_code=status.error_code,
+                                     reason=status.reason),
+            topk_query_result=results)
         return topk_result_list
 
     @mark_grpc_method
@@ -328,16 +314,14 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         _status, _table_name = Parser.parse_proto_TableName(request)
 
         if not _status.OK():
-            return milvus_pb2.TableSchema(
-                status=status_pb2.Status(error_code=_status.code, reason=_status.message),
-            )
+            return milvus_pb2.TableSchema(status=status_pb2.Status(
+                error_code=_status.code, reason=_status.message), )
 
-        metadata = {
-            'resp_class': milvus_pb2.TableSchema
-        }
+        metadata = {'resp_class': milvus_pb2.TableSchema}
 
         logger.info('DescribeTable {}'.format(_table_name))
-        _status, _table = self._describe_table(metadata=metadata, table_name=_table_name)
+        _status, _table = self._describe_table(metadata=metadata,
+                                               table_name=_table_name)
 
         if _status.OK():
             return milvus_pb2.TableSchema(
@@ -345,37 +329,38 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
                 index_file_size=_table.index_file_size,
                 dimension=_table.dimension,
                 metric_type=_table.metric_type,
-                status=status_pb2.Status(error_code=_status.code, reason=_status.message),
+                status=status_pb2.Status(error_code=_status.code,
+                                         reason=_status.message),
             )
 
         return milvus_pb2.TableSchema(
             table_name=_table_name,
-            status=status_pb2.Status(error_code=_status.code, reason=_status.message),
+            status=status_pb2.Status(error_code=_status.code,
+                                     reason=_status.message),
         )
 
     def _count_table(self, table_name, metadata=None):
-        return self.connection(metadata=metadata).get_table_row_count(table_name)
+        return self.connection(
+            metadata=metadata).get_table_row_count(table_name)
 
     @mark_grpc_method
     def CountTable(self, request, context):
         _status, _table_name = Parser.parse_proto_TableName(request)
 
         if not _status.OK():
-            status = status_pb2.Status(error_code=_status.code, reason=_status.message)
+            status = status_pb2.Status(error_code=_status.code,
+                                       reason=_status.message)
 
-            return milvus_pb2.TableRowCount(
-                status=status
-            )
+            return milvus_pb2.TableRowCount(status=status)
 
         logger.info('CountTable {}'.format(_table_name))
 
-        metadata = {
-            'resp_class': milvus_pb2.TableRowCount
-        }
+        metadata = {'resp_class': milvus_pb2.TableRowCount}
 
         _status, _count = self._count_table(_table_name, metadata=metadata)
 
         return milvus_pb2.TableRowCount(
-            status=status_pb2.Status(error_code=_status.code, reason=_status.message),
+            status=status_pb2.Status(error_code=_status.code,
+                                     reason=_status.message),
             table_row_count=_count if isinstance(_count, int) else -1)
 
     def _get_server_version(self, metadata=None):
@@ -387,23 +372,20 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         logger.info('Cmd: {}'.format(_cmd))
 
         if not _status.OK():
-            return milvus_pb2.StringReply(
-                status=status_pb2.Status(error_code=_status.code, reason=_status.message)
-            )
+            return milvus_pb2.StringReply(status=status_pb2.Status(
+                error_code=_status.code, reason=_status.message))
 
-        metadata = {
-            'resp_class': milvus_pb2.StringReply
-        }
+        metadata = {'resp_class': milvus_pb2.StringReply}
 
         if _cmd == 'version':
             _status, _reply = self._get_server_version(metadata=metadata)
         else:
-            _status, _reply = self.connection(metadata=metadata).server_status()
+            _status, _reply = self.connection(
+                metadata=metadata).server_status()
 
-        return milvus_pb2.StringReply(
-            status=status_pb2.Status(error_code=_status.code, reason=_status.message),
-            string_reply=_reply
-        )
+        return milvus_pb2.StringReply(status=status_pb2.Status(
+            error_code=_status.code, reason=_status.message),
+                                      string_reply=_reply)
 
     def _show_tables(self, metadata=None):
         return self.connection(metadata=metadata).show_tables()
@@ -411,18 +393,17 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
     @mark_grpc_method
     def ShowTables(self, request, context):
         logger.info('ShowTables')
-        metadata = {
-            'resp_class': milvus_pb2.TableName
-        }
+        metadata = {'resp_class': milvus_pb2.TableName}
 
         _status, _results = self._show_tables(metadata=metadata)
 
-        return milvus_pb2.TableNameList(
-            status=status_pb2.Status(error_code=_status.code, reason=_status.message),
-            table_names=_results
-        )
+        return milvus_pb2.TableNameList(status=status_pb2.Status(
+            error_code=_status.code, reason=_status.message),
+                                        table_names=_results)
 
     def _delete_by_range(self, table_name, start_date, end_date):
-        return self.connection().delete_vectors_by_range(table_name, start_date, end_date)
+        return self.connection().delete_vectors_by_range(table_name,
+                                                         start_date,
+                                                         end_date)
 
     @mark_grpc_method
     def DeleteByRange(self, request, context):
@@ -430,13 +411,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
             Parser.parse_proto_DeleteByRangeParam(request)
 
         if not _status.OK():
-            return status_pb2.Status(error_code=_status.code, reason=_status.message)
+            return status_pb2.Status(error_code=_status.code,
+                                     reason=_status.message)
 
         _table_name, _start_date, _end_date = unpacks
 
-        logger.info('DeleteByRange {}: {} {}'.format(_table_name, _start_date, _end_date))
+        logger.info('DeleteByRange {}: {} {}'.format(_table_name, _start_date,
+                                                     _end_date))
         _status = self._delete_by_range(_table_name, _start_date, _end_date)
-        return status_pb2.Status(error_code=_status.code, reason=_status.message)
+        return status_pb2.Status(error_code=_status.code,
+                                 reason=_status.message)
 
     def _preload_table(self, table_name):
         return self.connection().preload_table(table_name)
@@ -446,11 +430,13 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         _status, _table_name = Parser.parse_proto_TableName(request)
 
         if not _status.OK():
-            return status_pb2.Status(error_code=_status.code, reason=_status.message)
+            return status_pb2.Status(error_code=_status.code,
+                                     reason=_status.message)
 
         logger.info('PreloadTable {}'.format(_table_name))
         _status = self._preload_table(_table_name)
-        return status_pb2.Status(error_code=_status.code, reason=_status.message)
+        return status_pb2.Status(error_code=_status.code,
+                                 reason=_status.message)
 
     def _describe_index(self, table_name, metadata=None):
         return self.connection(metadata=metadata).describe_index(table_name)
@@ -460,21 +446,22 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         _status, _table_name = Parser.parse_proto_TableName(request)
 
         if not _status.OK():
-            return milvus_pb2.IndexParam(
-                status=status_pb2.Status(error_code=_status.code, reason=_status.message)
-            )
+            return milvus_pb2.IndexParam(status=status_pb2.Status(
+                error_code=_status.code, reason=_status.message))
 
-        metadata = {
-            'resp_class': milvus_pb2.IndexParam
-        }
+        metadata = {'resp_class': milvus_pb2.IndexParam}
 
         logger.info('DescribeIndex {}'.format(_table_name))
-        _status, _index_param = self._describe_index(table_name=_table_name, metadata=metadata)
+        _status, _index_param = self._describe_index(table_name=_table_name,
+                                                     metadata=metadata)
 
-        _index = milvus_pb2.Index(index_type=_index_param._index_type, nlist=_index_param._nlist)
+        _index = milvus_pb2.Index(index_type=_index_param._index_type,
+                                  nlist=_index_param._nlist)
 
-        return milvus_pb2.IndexParam(status=status_pb2.Status(error_code=_status.code, reason=_status.message),
-                                     table_name=_table_name, index=_index)
+        return milvus_pb2.IndexParam(status=status_pb2.Status(
+            error_code=_status.code, reason=_status.message),
+                                     table_name=_table_name,
+                                     index=_index)
 
     def _drop_index(self, table_name):
         return self.connection().drop_index(table_name)
@@ -484,8 +471,10 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         _status, _table_name = Parser.parse_proto_TableName(request)
 
         if not _status.OK():
-            return status_pb2.Status(error_code=_status.code, reason=_status.message)
+            return status_pb2.Status(error_code=_status.code,
+                                     reason=_status.message)
 
         logger.info('DropIndex {}'.format(_table_name))
         _status = self._drop_index(_table_name)
-        return status_pb2.Status(error_code=_status.code, reason=_status.message)
+        return status_pb2.Status(error_code=_status.code,
+                                 reason=_status.message)
diff --git a/mishards/settings.py b/mishards/settings.py
index c9b62717d4..5e81a1a8ad 100644
--- a/mishards/settings.py
+++ b/mishards/settings.py
@@ -73,12 +73,14 @@ class DefaultConfig:
     SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_URI')
     SQL_ECHO = env.bool('SQL_ECHO', False)
     TRACING_TYPE = env.str('TRACING_TYPE', '')
+    ROUTER_CLASS_NAME = env.str('ROUTER_CLASS_NAME', 'FileBasedHashRingRouter')
 
 
 class TestingConfig(DefaultConfig):
     SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_TEST_URI')
     SQL_ECHO = env.bool('SQL_TEST_ECHO', False)
     TRACING_TYPE = env.str('TRACING_TEST_TYPE', '')
+    ROUTER_CLASS_NAME = env.str('ROUTER_CLASS_TEST_NAME', 'FileBasedHashRingRouter')
 
 
 if __name__ == '__main__':
diff --git a/mishards/utilities.py b/mishards/utilities.py
new file mode 100644
index 0000000000..c08d0d42df
--- /dev/null
+++ b/mishards/utilities.py
@@ -0,0 +1,20 @@
+import datetime
+from mishards import exceptions
+
+
+def format_date(start, end):
+    return ((start.year - 1900) * 10000 + (start.month - 1) * 100 + start.day,
+            (end.year - 1900) * 10000 + (end.month - 1) * 100 + end.day)
+
+
+def range_to_date(range_obj, metadata=None):
+    try:
+        start = datetime.datetime.strptime(range_obj.start_date, '%Y-%m-%d')
+        end = datetime.datetime.strptime(range_obj.end_date, '%Y-%m-%d')
+        assert start < end
+    except (ValueError, AssertionError):
+        raise exceptions.InvalidRangeError('Invalid time range: {} {}'.format(
+            range_obj.start_date, range_obj.end_date),
+                                           metadata=metadata)
+
+    return format_date(start, end)
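Reviewer note: this patch turns the routing strategy into a pluggable component. A router registers itself with `RouteManager.register_router_class`, is picked by name through the new `ROUTER_CLASS_NAME` setting, and is built by `RouterFactory.new_router` in `create_app`, so `ServiceHandler` only ever calls `self.router.routing(...)`. The snippet below is a minimal sketch of that extension point; `RandomRouter` and `StubConnMgr` are hypothetical illustrations and not part of this patch.

```python
# Hypothetical sketch: plugging a custom router into the registry added by
# this patch. Only RouteManager, RouterMixin and RouterFactory come from
# mishards.routings; RandomRouter and StubConnMgr are illustration-only stubs.
import random

from mishards.routings import RouteManager, RouterFactory, RouterMixin


@RouteManager.register_router_class
class RandomRouter(RouterMixin):
    # Selected by setting ROUTER_CLASS_NAME=RandomRouter in the environment.
    NAME = 'RandomRouter'

    def routing(self, table_name, metadata=None, **kwargs):
        # Send the whole table to one randomly chosen downstream server
        # instead of hashing file ids onto the ring.
        servers = list(self.conn_mgr.conn_names)
        target = random.choice(servers)
        return {target: {'table_id': table_name, 'file_ids': []}}


class StubConnMgr:
    # Minimal stand-in for the real connection manager used in mishards.
    conn_names = {'server_a', 'server_b'}


router = RouterFactory.new_router('RandomRouter', StubConnMgr())
print(router.routing('demo_table'))
```

The built-in `FileBasedHashRingRouter` is registered the same way, which is why the only wiring needed in `create_app` is the `RouterFactory.new_router(config.ROUTER_CLASS_NAME, connect_mgr)` call.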