diff --git a/__init__.py b/__init__.py deleted file mode 100644 index 7db5c41bd0..0000000000 --- a/__init__.py +++ /dev/null @@ -1 +0,0 @@ -import settings diff --git a/mishards/__init__.py b/mishards/__init__.py new file mode 100644 index 0000000000..700dd4238c --- /dev/null +++ b/mishards/__init__.py @@ -0,0 +1,6 @@ +import settings +from connections import ConnectionMgr +connect_mgr = ConnectionMgr() + +from server import Server +grpc_server = Server(conn_mgr=connect_mgr) diff --git a/connections.py b/mishards/connections.py similarity index 99% rename from connections.py rename to mishards/connections.py index c52a1c5f85..06d5f3ff16 100644 --- a/connections.py +++ b/mishards/connections.py @@ -89,7 +89,7 @@ class ConnectionMgr: threaded = { threading.get_ident() : this_conn } - c[name] = threaded + self.conns[name] = threaded return this_conn tid = threading.get_ident() diff --git a/exception_codes.py b/mishards/exception_codes.py similarity index 100% rename from exception_codes.py rename to mishards/exception_codes.py diff --git a/exceptions.py b/mishards/exceptions.py similarity index 100% rename from exceptions.py rename to mishards/exceptions.py diff --git a/mishards/grpc_utils/__init__.py b/mishards/grpc_utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/mishards/grpc_utils/grpc_args_parser.py b/mishards/grpc_utils/grpc_args_parser.py new file mode 100644 index 0000000000..c8dc9d71d9 --- /dev/null +++ b/mishards/grpc_utils/grpc_args_parser.py @@ -0,0 +1,101 @@ +from milvus import Status +from functools import wraps + + +def error_status(func): + @wraps(func) + def inner(*args, **kwargs): + try: + results = func(*args, **kwargs) + except Exception as e: + return Status(code=Status.UNEXPECTED_ERROR, message=str(e)), None + + return Status(code=0, message="Success"), results + + return inner + + +class GrpcArgsParser(object): + + @classmethod + @error_status + def parse_proto_TableSchema(cls, param): + _table_schema = { + 'table_name': param.table_name.table_name, + 'dimension': param.dimension, + 'index_file_size': param.index_file_size, + 'metric_type': param.metric_type + } + + return _table_schema + + @classmethod + @error_status + def parse_proto_TableName(cls, param): + return param.table_name + + @classmethod + @error_status + def parse_proto_Index(cls, param): + _index = { + 'index_type': param.index_type, + 'nlist': param.nlist + } + + return _index + + @classmethod + @error_status + def parse_proto_IndexParam(cls, param): + _table_name = param.table_name.table_name + _status, _index = cls.parse_proto_Index(param.index) + + if not _status.OK(): + raise Exception("Argument parse error") + + return _table_name, _index + + @classmethod + @error_status + def parse_proto_Command(cls, param): + _cmd = param.cmd + + return _cmd + + @classmethod + @error_status + def parse_proto_Range(cls, param): + _start_value = param.start_value + _end_value = param.end_value + + return _start_value, _end_value + + @classmethod + @error_status + def parse_proto_RowRecord(cls, param): + return list(param.vector_data) + + @classmethod + @error_status + def parse_proto_SearchParam(cls, param): + _table_name = param.table_name + _topk = param.topk + _nprobe = param.nprobe + _status, _range = cls.parse_proto_Range(param.query_range_array) + + if not _status.OK(): + raise Exception("Argument parse error") + + _row_record = param.query_record_array + + return _table_name, _row_record, _range, _topk + + @classmethod + @error_status + def parse_proto_DeleteByRangeParam(cls, param): + _table_name = param.table_name + _range = param.range + _start_value = _range.start_value + _end_value = _range.end_value + + return _table_name, _start_value, _end_value diff --git a/mishards/grpc_utils/grpc_args_wrapper.py b/mishards/grpc_utils/grpc_args_wrapper.py new file mode 100644 index 0000000000..a864b1e400 --- /dev/null +++ b/mishards/grpc_utils/grpc_args_wrapper.py @@ -0,0 +1,4 @@ +# class GrpcArgsWrapper(object): + + # @classmethod + # def proto_TableName(cls): \ No newline at end of file diff --git a/mishards/main.py b/mishards/main.py new file mode 100644 index 0000000000..0185e6ac1d --- /dev/null +++ b/mishards/main.py @@ -0,0 +1,14 @@ +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import settings +from mishards import connect_mgr, grpc_server as server + +def main(): + connect_mgr.register('WOSERVER', settings.WOSERVER) + server.run(port=settings.SERVER_PORT) + return 0 + +if __name__ == '__main__': + sys.exit(main()) diff --git a/mishards/server.py b/mishards/server.py new file mode 100644 index 0000000000..59ea7db46b --- /dev/null +++ b/mishards/server.py @@ -0,0 +1,47 @@ +import logging +import grpc +import time +from concurrent import futures +from grpc._cython import cygrpc +from milvus.grpc_gen.milvus_pb2_grpc import add_MilvusServiceServicer_to_server +from service_handler import ServiceHandler +import settings + +logger = logging.getLogger(__name__) + + +class Server: + def __init__(self, conn_mgr, port=19530, max_workers=10, **kwargs): + self.exit_flag = False + self.port = int(port) + self.conn_mgr = conn_mgr + self.server_impl = grpc.server( + thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers), + options=[(cygrpc.ChannelArgKey.max_send_message_length, -1), + (cygrpc.ChannelArgKey.max_receive_message_length, -1)] + ) + + def start(self, port=None): + add_MilvusServiceServicer_to_server(ServiceHandler(conn_mgr=self.conn_mgr), self.server_impl) + self.server_impl.add_insecure_port("[::]:{}".format(str(port or self._port))) + self.server_impl.start() + + def run(self, port): + logger.info('Milvus server start ......') + port = port or self.port + + self.start(port) + logger.info('Successfully') + logger.info('Listening on port {}'.format(port)) + + try: + while not self.exit_flag: + time.sleep(5) + except KeyboardInterrupt: + self.stop() + + def stop(self): + logger.info('Server is shuting down ......') + self.exit_flag = True + self.server.stop(0) + logger.info('Server is closed') diff --git a/mishards/service_handler.py b/mishards/service_handler.py new file mode 100644 index 0000000000..ead8d14d88 --- /dev/null +++ b/mishards/service_handler.py @@ -0,0 +1,327 @@ +import logging +from contextlib import contextmanager +from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2 + +from grpc_utils.grpc_args_parser import GrpcArgsParser as Parser + +logger = logging.getLogger(__name__) + + +class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): + def __init__(self, conn_mgr, *args, **kwargs): + self.conn_mgr = conn_mgr + self.table_meta = {} + + @property + def connection(self): + conn = self.conn_mgr.conn('WOSERVER') + if conn: + conn.on_connect() + return conn.conn + + def CreateTable(self, request, context): + _status, _table_schema = Parser.parse_proto_TableSchema(request) + + if not _status.OK(): + return status_pb2.Status(error_code=_status.code, reason=_status.message) + + logger.info('CreateTable {}'.format(_table_schema['table_name'])) + + _status = self.connection.create_table(_table_schema) + + return status_pb2.Status(error_code=_status.code, reason=_status.message) + + def HasTable(self, request, context): + _status, _table_name = Parser.parse_proto_TableName(request) + + if not _status.OK(): + return milvus_pb2.BoolReply( + status=status_pb2.Status(error_code=_status.code, reason=_status.message), + bool_reply=False + ) + + logger.info('HasTable {}'.format(_table_name)) + + _bool = self.connection.has_table(_table_name) + + return milvus_pb2.BoolReply( + status=status_pb2.Status(error_code=status_pb2.SUCCESS, reason="OK"), + bool_reply=_bool + ) + + def DropTable(self, request, context): + _status, _table_name = Parser.parse_proto_TableName(request) + + if not _status.OK(): + return status_pb2.Status(error_code=_status.code, reason=_status.message) + + logger.info('DropTable {}'.format(_table_name)) + + _status = self.connection.delete_table(_table_name) + + return status_pb2.Status(error_code=_status.code, reason=_status.message) + + def CreateIndex(self, request, context): + _status, unpacks = Parser.parse_proto_IndexParam(request) + + if not _status.OK(): + return status_pb2.Status(error_code=_status.code, reason=_status.message) + + _table_name, _index = unpacks + + logger.info('CreateIndex {}'.format(_table_name)) + + # TODO: interface create_table incompleted + _status = self.connection.create_index(_table_name, _index) + + return status_pb2.Status(error_code=_status.code, reason=_status.message) + + def Insert(self, request, context): + logger.info('Insert') + # TODO: Ths SDK interface add_vectors() could update, add a key 'row_id_array' + _status, _ids = self.connection.add_vectors(None, None, insert_param=request) + return milvus_pb2.VectorIds( + status=status_pb2.Status(error_code=_status.code, reason=_status.message), + vector_id_array=_ids + ) + + def Search(self, request, context): + + try: + table_name = request.table_name + + topk = request.topk + nprobe = request.nprobe + + logger.info('Search {}: topk={} nprobe={}'.format(table_name, topk, nprobe)) + + if nprobe > 2048 or nprobe <= 0: + raise exceptions.GRPCInvlidArgument('Invalid nprobe: {}'.format(nprobe)) + + table_meta = self.table_meta.get(table_name, None) + if not table_meta: + status, info = self.connection.describe_table(table_name) + if not status.OK(): + raise TableNotFoundException(table_name) + + self.table_meta[table_name] = info + table_meta = info + + start = time.time() + + query_record_array = [] + + for query_record in request.query_record_array: + query_record_array.append(list(query_record.vector_data)) + + query_range_array = [] + for query_range in request.query_range_array: + query_range_array.append( + Range(query_range.start_value, query_range.end_value)) + except (TableNotFoundException, exceptions.GRPCInvlidArgument) as exc: + return milvus_pb2.TopKQueryResultList( + status=status_pb2.Status(error_code=exc.code, reason=exc.message) + ) + except Exception as e: + return milvus_pb2.TopKQueryResultList( + status=status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, reason=str(e)) + ) + + try: + results = workflow.query_vectors(table_name, table_meta, query_record_array, topk, + nprobe, query_range_array) + except (exceptions.GRPCQueryInvalidRangeException, TableNotFoundException) as exc: + return milvus_pb2.TopKQueryResultList( + status=status_pb2.Status(error_code=exc.code, reason=exc.message) + ) + except exceptions.ServiceNotFoundException as exc: + return milvus_pb2.TopKQueryResultList( + status=status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, reason=exc.message) + ) + except Exception as e: + logger.error(e) + results = workflow.query_vectors(table_name, table_meta, query_record_array, + topk, nprobe, query_range_array) + + now = time.time() + logger.info('SearchVector Ends @{}'.format(now)) + logger.info('SearchVector takes: {}'.format(now - start)) + + topk_result_list = milvus_pb2.TopKQueryResultList( + status=status_pb2.Status(error_code=status_pb2.SUCCESS, reason="Success"), + topk_query_result=results + ) + return topk_result_list + + def SearchInFiles(self, request, context): + try: + file_id_array = list(request.file_id_array) + search_param = request.search_param + table_name = search_param.table_name + topk = search_param.topk + nprobe = search_param.nprobe + + query_record_array = [] + + for query_record in search_param.query_record_array: + query_record_array.append(list(query_record)) + + query_range_array = [] + for query_range in search_param.query_range_array: + query_range_array.append("") + except Exception as e: + milvus_pb2.TopKQueryResultList( + status=status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, reason=str(e)), + ) + + res = search_vector_in_files.delay(table_name=table_name, + file_id_array=file_id_array, + query_record_array=query_record_array, + query_range_array=query_range_array, + topk=topk, + nprobe=nprobe) + status, result = res.get(timeout=1) + + if not status.OK(): + raise ThriftException(code=status.code, reason=status.message) + res = TopKQueryResult() + for top_k_query_results in result: + res.query_result_arrays.append([QueryResult(id=qr.id, distance=qr.distance) + for qr in top_k_query_results]) + return res + + def DescribeTable(self, request, context): + _status, _table_name = Parser.parse_proto_TableName(request) + + if not _status.OK(): + table_name = milvus_pb2.TableName( + status=status_pb2.Status(error_code=_status.code, reason=_status.message) + ) + return milvus_pb2.TableSchema( + table_name=table_name + ) + + logger.info('DescribeTable {}'.format(_table_name)) + _status, _table = self.connection.describe_table(_table_name) + + if _status.OK(): + _grpc_table_name = milvus_pb2.TableName( + status=status_pb2.Status(error_code=_status.code, reason=_status.message), + table_name=_table.table_name + ) + + return milvus_pb2.TableSchema( + table_name=_grpc_table_name, + index_file_size=_table.index_file_size, + dimension=_table.dimension, + metric_type=_table.metric_type + ) + + return milvus_pb2.TableSchema( + table_name=milvus_pb2.TableName( + status=status_pb2.Status(error_code=_status.code, reason=_status.message) + ) + ) + + def CountTable(self, request, context): + _status, _table_name = Parser.parse_proto_TableName(request) + + if not _status.OK(): + status = status_pb2.Status(error_code=_status.code, reason=_status.message) + + return milvus_pb2.TableRowCount( + status=status + ) + + logger.info('CountTable {}'.format(_table_name)) + + _status, _count = self.connection.get_table_row_count(_table_name) + + return milvus_pb2.TableRowCount( + status=status_pb2.Status(error_code=_status.code, reason=_status.message), + table_row_count=_count if isinstance(_count, int) else -1) + + def Cmd(self, request, context): + _status, _cmd = Parser.parse_proto_Command(request) + logger.info('Cmd: {}'.format(_cmd)) + + if not _status.OK(): + return milvus_pb2.StringReply( + status_pb2.Status(error_code=_status.code, reason=_status.message) + ) + + if _cmd == 'version': + _status, _reply = self.connection.server_version() + else: + _status, _reply = self.connection.server_status() + + return milvus_pb2.StringReply( + status=status_pb2.Status(error_code=_status.code, reason=_status.message), + string_reply=_reply + ) + + def ShowTables(self, request, context): + logger.info('ShowTables') + _status, _results = self.connection.show_tables() + + if not _status.OK(): + _results = [] + + for _result in _results: + yield milvus_pb2.TableName( + status=status_pb2.Status(error_code=_status.code, reason=_status.message), + table_name=_result + ) + + def DeleteByRange(self, request, context): + _status, unpacks = \ + Parser.parse_proto_DeleteByRangeParam(request) + + if not _status.OK(): + return status_pb2.Status(error_code=_status.code, reason=_status.message) + + _table_name, _start_date, _end_date = unpacks + + logger.info('DeleteByRange {}: {} {}'.format(_table_name, _start_date, _end_date)) + _status = self.connection.delete_vectors_by_range(_table_name, _start_date, _end_date) + return status_pb2.Status(error_code=_status.code, reason=_status.message) + + def PreloadTable(self, request, context): + _status, _table_name = Parser.parse_proto_TableName(request) + + if not _status.OK(): + return status_pb2.Status(error_code=_status.code, reason=_status.message) + + logger.info('PreloadTable {}'.format(_table_name)) + _status = self.connection.preload_table(_table_name) + return status_pb2.Status(error_code=_status.code, reason=_status.message) + + def DescribeIndex(self, request, context): + _status, _table_name = Parser.parse_proto_TableName(request) + + if not _status.OK(): + return milvus_pb2.IndexParam( + table_name=milvus_pb2.TableName( + status=status_pb2.Status(error_code=_status.code, reason=_status.message) + ) + ) + + logger.info('DescribeIndex {}'.format(_table_name)) + _status, _index_param = self.connection.describe_index(_table_name) + + _index = milvus_pb2.Index(index_type=_index_param._index_type, nlist=_index_param._nlist) + _tablename = milvus_pb2.TableName( + status=status_pb2.Status(error_code=_status.code, reason=_status.message), + table_name=_table_name) + + return milvus_pb2.IndexParam(table_name=_tablename, index=_index) + + def DropIndex(self, request, context): + _status, _table_name = Parser.parse_proto_TableName(request) + + if not _status.OK(): + return status_pb2.Status(error_code=_status.code, reason=_status.message) + + logger.info('DropIndex {}'.format(_table_name)) + _status = self.connection.drop_index(_table_name) + return status_pb2.Status(error_code=_status.code, reason=_status.message) diff --git a/settings.py b/mishards/settings.py similarity index 90% rename from settings.py rename to mishards/settings.py index 4ad00e66cb..0566cf066f 100644 --- a/settings.py +++ b/mishards/settings.py @@ -22,6 +22,8 @@ config(LOG_LEVEL, LOG_PATH, LOG_NAME, TIMEZONE) TIMEOUT = env.int('TIMEOUT', 60) MAX_RETRY = env.int('MAX_RETRY', 3) +SERVER_PORT = env.int('SERVER_PORT', 19530) +WOSERVER = env.str('WOSERVER') if __name__ == '__main__': import logging diff --git a/utils/__init__.py b/mishards/utils/__init__.py similarity index 100% rename from utils/__init__.py rename to mishards/utils/__init__.py diff --git a/utils/logger_helper.py b/mishards/utils/logger_helper.py similarity index 100% rename from utils/logger_helper.py rename to mishards/utils/logger_helper.py diff --git a/service_handler.py b/service_handler.py deleted file mode 100644 index d5018a54d8..0000000000 --- a/service_handler.py +++ /dev/null @@ -1,11 +0,0 @@ -import logging - -import grpco -from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2 - -logger = logging.getLogger(__name__) - - -class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): - def __init__(self, connections, *args, **kwargs): - self.connections = self.connections