add router in impl
parent 560c4310ae
commit a3409be0dc
@@ -27,7 +27,10 @@ def create_app(testing_config=None):
    tracer = TracerFactory.new_tracer(config.TRACING_TYPE, settings.TracingConfig,
                                      span_decorator=GrpcSpanDecorator())

    grpc_server.init_app(conn_mgr=connect_mgr, tracer=tracer, discover=discover)
    from mishards.routings import RouterFactory
    router = RouterFactory.new_router(config.ROUTER_CLASS_NAME, connect_mgr)

    grpc_server.init_app(conn_mgr=connect_mgr, tracer=tracer, router=router, discover=discover)

    from mishards import exception_handlers
mishards/routings.py (new file, 81 lines)
@@ -0,0 +1,81 @@
import logging
from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy import and_

from mishards import exceptions, db
from mishards.hash_ring import HashRing
from mishards.models import Tables

logger = logging.getLogger(__name__)


class RouteManager:
    ROUTER_CLASSES = {}

    @classmethod
    def register_router_class(cls, target):
        name = target.__dict__.get('NAME', None)
        name = name if name else target.__name__
        cls.ROUTER_CLASSES[name] = target
        return target

    @classmethod
    def get_router_class(cls, name):
        return cls.ROUTER_CLASSES.get(name, None)


class RouterFactory:
    @classmethod
    def new_router(cls, name, conn_mgr, **kwargs):
        router_class = RouteManager.get_router_class(name)
        assert router_class
        return router_class(conn_mgr, **kwargs)


class RouterMixin:
    def __init__(self, conn_mgr):
        self.conn_mgr = conn_mgr

    def routing(self, table_name, metadata=None, **kwargs):
        raise NotImplementedError()


@RouteManager.register_router_class
class FileBasedHashRingRouter(RouterMixin):
    NAME = 'FileBasedHashRingRouter'

    def __init__(self, conn_mgr, **kwargs):
        super(FileBasedHashRingRouter, self).__init__(conn_mgr)

    def routing(self, table_name, metadata=None, **kwargs):
        range_array = kwargs.pop('range_array', None)
        return self._route(table_name, range_array, metadata, **kwargs)

    def _route(self, table_name, range_array, metadata=None, **kwargs):
        # PXU TODO: Implement Thread-local Context
        try:
            table = db.Session.query(Tables).filter(
                and_(Tables.table_id == table_name,
                     Tables.state != Tables.TO_DELETE)).first()
        except sqlalchemy_exc.SQLAlchemyError as e:
            raise exceptions.DBError(message=str(e), metadata=metadata)

        if not table:
            raise exceptions.TableNotFoundError(table_name, metadata=metadata)
        files = table.files_to_search(range_array)

        servers = self.conn_mgr.conn_names
        logger.info('Available servers: {}'.format(servers))

        ring = HashRing(servers)

        routing = {}

        for f in files:
            target_host = ring.get_node(str(f.id))
            sub = routing.get(target_host, None)
            if not sub:
                routing[target_host] = {'table_id': table_name, 'file_ids': []}
            routing[target_host]['file_ids'].append(str(f.id))

        return routing
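The registry above makes the router pluggable: any class decorated with `RouteManager.register_router_class` becomes selectable by name through `RouterFactory.new_router`. A minimal sketch of a custom router, where the class name and the trivial routing policy are illustrative assumptions rather than part of this commit:

import logging
from mishards.routings import RouteManager, RouterFactory, RouterMixin

logger = logging.getLogger(__name__)


@RouteManager.register_router_class
class SingleNodeRouter(RouterMixin):
    # Hypothetical router: sends every request for a table to one fixed node.
    NAME = 'SingleNodeRouter'

    def __init__(self, conn_mgr, **kwargs):
        super(SingleNodeRouter, self).__init__(conn_mgr)

    def routing(self, table_name, metadata=None, **kwargs):
        # Pick the first known read node; file_ids left empty for illustration.
        servers = sorted(self.conn_mgr.conn_names)
        return {servers[0]: {'table_id': table_name, 'file_ids': []}}


# Selection by name mirrors what create_app() does with config.ROUTER_CLASS_NAME:
# router = RouterFactory.new_router('SingleNodeRouter', conn_mgr)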
@@ -22,17 +22,24 @@ class Server:
        self.error_handlers = {}
        self.exit_flag = False

    def init_app(self, conn_mgr, tracer, discover, port=19530, max_workers=10, **kwargs):
    def init_app(self,
                 conn_mgr,
                 tracer,
                 router,
                 discover,
                 port=19530,
                 max_workers=10,
                 **kwargs):
        self.port = int(port)
        self.conn_mgr = conn_mgr
        self.tracer = tracer
        self.router = router
        self.discover = discover

        self.server_impl = grpc.server(
            thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers),
            options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
                     (cygrpc.ChannelArgKey.max_receive_message_length, -1)]
        )
                     (cygrpc.ChannelArgKey.max_receive_message_length, -1)])

        self.server_impl = self.tracer.decorate(self.server_impl)
@@ -43,8 +50,8 @@ class Server:
        url = urlparse(woserver)
        ip = socket.gethostbyname(url.hostname)
        socket.inet_pton(socket.AF_INET, ip)
        self.conn_mgr.register('WOSERVER',
                               '{}://{}:{}'.format(url.scheme, ip, url.port or 80))
        self.conn_mgr.register(
            'WOSERVER', '{}://{}:{}'.format(url.scheme, ip, url.port or 80))

    def register_pre_run_handler(self, func):
        logger.info('Registering {} into server pre_run_handlers'.format(func))
@@ -65,9 +72,11 @@ class Server:

    def errorhandler(self, exception):
        if inspect.isclass(exception) and issubclass(exception, Exception):

            def wrapper(func):
                self.error_handlers[exception] = func
                return func

            return wrapper
        return exception
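For context, `errorhandler` is a decorator factory: calling it with an exception class returns a wrapper that records the decorated function in `self.error_handlers`. A usage sketch, where the `server` instance and the handler body are assumptions for illustration (the real registrations live in `mishards.exception_handlers`):

import logging
from mishards import exceptions

logger = logging.getLogger(__name__)


def register_handlers(server):
    # `server` is assumed to be an already-constructed mishards Server instance.
    @server.errorhandler(exceptions.TableNotFoundError)
    def handle_table_not_found(err):
        # Illustrative body only: a real handler builds the proper gRPC reply.
        logger.warning('Handled missing table: {}'.format(err))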
@@ -78,8 +87,12 @@ class Server:

    def start(self, port=None):
        handler_class = self.decorate_handler(ServiceHandler)
        add_MilvusServiceServicer_to_server(handler_class(conn_mgr=self.conn_mgr, tracer=self.tracer), self.server_impl)
        self.server_impl.add_insecure_port("[::]:{}".format(str(port or self._port)))
        add_MilvusServiceServicer_to_server(
            handler_class(conn_mgr=self.conn_mgr,
                          tracer=self.tracer,
                          router=self.router), self.server_impl)
        self.server_impl.add_insecure_port("[::]:{}".format(
            str(port or self._port)))
        self.server_impl.start()

    def run(self, port):
@@ -3,9 +3,6 @@ import time
import datetime
from collections import defaultdict

from sqlalchemy import and_
from sqlalchemy import exc as sqlalchemy_exc

from concurrent.futures import ThreadPoolExecutor
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
from milvus.grpc_gen.milvus_pb2 import TopKQueryResult
@@ -15,8 +12,7 @@ from milvus.client import types as Types
from mishards import (db, settings, exceptions)
from mishards.grpc_utils import mark_grpc_method
from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser
from mishards.models import Tables, TableFiles
from mishards.hash_ring import HashRing
from mishards import utilities

logger = logging.getLogger(__name__)
@@ -24,11 +20,12 @@ logger = logging.getLogger(__name__)
class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
    MAX_NPROBE = 2048

    def __init__(self, conn_mgr, tracer, *args, **kwargs):
    def __init__(self, conn_mgr, tracer, router, *args, **kwargs):
        self.conn_mgr = conn_mgr
        self.table_meta = {}
        self.error_handlers = {}
        self.tracer = tracer
        self.router = router

    def connection(self, metadata=None):
        conn = self.conn_mgr.conn('WOSERVER', metadata=metadata)
@@ -43,56 +40,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
        conn.on_connect(metadata=metadata)
        return conn.conn

    def _format_date(self, start, end):
        return ((start.year - 1900) * 10000 + (start.month - 1) * 100 + start.day, (end.year - 1900) * 10000 + (end.month - 1) * 100 + end.day)

    def _range_to_date(self, range_obj, metadata=None):
        try:
            start = datetime.datetime.strptime(range_obj.start_date, '%Y-%m-%d')
            end = datetime.datetime.strptime(range_obj.end_date, '%Y-%m-%d')
            assert start < end
        except (ValueError, AssertionError):
            raise exceptions.InvalidRangeError('Invalid time range: {} {}'.format(
                range_obj.start_date, range_obj.end_date
            ), metadata=metadata)

        return self._format_date(start, end)

    def _get_routing_file_ids(self, table_id, range_array, metadata=None):
        # PXU TODO: Implement Thread-local Context
        try:
            table = db.Session.query(Tables).filter(and_(
                Tables.table_id == table_id,
                Tables.state != Tables.TO_DELETE
            )).first()
        except sqlalchemy_exc.SQLAlchemyError as e:
            raise exceptions.DBError(message=str(e), metadata=metadata)

        if not table:
            raise exceptions.TableNotFoundError(table_id, metadata=metadata)
        files = table.files_to_search(range_array)

        servers = self.conn_mgr.conn_names
        logger.info('Available servers: {}'.format(servers))

        ring = HashRing(servers)

        routing = {}

        for f in files:
            target_host = ring.get_node(str(f.id))
            sub = routing.get(target_host, None)
            if not sub:
                routing[target_host] = {
                    'table_id': table_id,
                    'file_ids': []
                }
            routing[target_host]['file_ids'].append(str(f.id))

        return routing

    def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs):
        status = status_pb2.Status(error_code=status_pb2.SUCCESS, reason="Success")
        status = status_pb2.Status(error_code=status_pb2.SUCCESS,
                                   reason="Success")
        if not files_n_topk_results:
            return status, []
@@ -103,10 +53,14 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
            if isinstance(files_collection, tuple):
                status, _ = files_collection
                return status, []
            for request_pos, each_request_results in enumerate(files_collection.topk_query_result):
                request_results[request_pos].extend(each_request_results.query_result_arrays)
                request_results[request_pos] = sorted(request_results[request_pos], key=lambda x: x.distance,
                                                      reverse=reverse)[:topk]
            for request_pos, each_request_results in enumerate(
                    files_collection.topk_query_result):
                request_results[request_pos].extend(
                    each_request_results.query_result_arrays)
                request_results[request_pos] = sorted(
                    request_results[request_pos],
                    key=lambda x: x.distance,
                    reverse=reverse)[:topk]

        calc_time = time.time() - calc_time
        logger.info('Merge takes {}'.format(calc_time))
@@ -120,15 +74,27 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):

        return status, topk_query_result

    def _do_query(self, context, table_id, table_meta, vectors, topk, nprobe, range_array=None, **kwargs):
    def _do_query(self,
                  context,
                  table_id,
                  table_meta,
                  vectors,
                  topk,
                  nprobe,
                  range_array=None,
                  **kwargs):
        metadata = kwargs.get('metadata', None)
        range_array = [self._range_to_date(r, metadata=metadata) for r in range_array] if range_array else None
        range_array = [
            utilities.range_to_date(r, metadata=metadata) for r in range_array
        ] if range_array else None

        routing = {}
        p_span = None if self.tracer.empty else context.get_active_span().context
        with self.tracer.start_span('get_routing',
                                    child_of=p_span):
            routing = self._get_routing_file_ids(table_id, range_array, metadata=metadata)
        p_span = None if self.tracer.empty else context.get_active_span(
        ).context
        with self.tracer.start_span('get_routing', child_of=p_span):
            routing = self.router.routing(table_id,
                                          range_array=range_array,
                                          metadata=metadata)
        logger.info('Routing: {}'.format(routing))

        metadata = kwargs.get('metadata', None)
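To make the handoff concrete, the value returned by `self.router.routing(...)` is a plain dict keyed by target server address, matching what `_route` builds above. A hypothetical result for a table whose segment files hash to two read nodes (addresses and file ids are made up for illustration):

# Illustrative shape only; real keys and file ids come from the connection
# manager and the metadata database.
routing = {
    'ro_server_1': {'table_id': 'example_table', 'file_ids': ['101', '103']},
    'ro_server_2': {'table_id': 'example_table', 'file_ids': ['102']},
}
# _do_query later fans out one search_vectors_in_files() call per key.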
@@ -139,42 +105,51 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
        workers = settings.SEARCH_WORKER_SIZE

        def search(addr, query_params, vectors, topk, nprobe, **kwargs):
            logger.info('Send Search Request: addr={};params={};nq={};topk={};nprobe={}'.format(
                addr, query_params, len(vectors), topk, nprobe
            ))
            logger.info(
                'Send Search Request: addr={};params={};nq={};topk={};nprobe={}'
                .format(addr, query_params, len(vectors), topk, nprobe))

            conn = self.query_conn(addr, metadata=metadata)
            start = time.time()
            span = kwargs.get('span', None)
            span = span if span else (None if self.tracer.empty else context.get_active_span().context)
            span = span if span else (None if self.tracer.empty else
                                      context.get_active_span().context)

            with self.tracer.start_span('search_{}'.format(addr),
                                        child_of=span):
                ret = conn.search_vectors_in_files(table_name=query_params['table_id'],
                                                   file_ids=query_params['file_ids'],
                                                   query_records=vectors,
                                                   top_k=topk,
                                                   nprobe=nprobe,
                                                   lazy=True)
                ret = conn.search_vectors_in_files(
                    table_name=query_params['table_id'],
                    file_ids=query_params['file_ids'],
                    query_records=vectors,
                    top_k=topk,
                    nprobe=nprobe,
                    lazy=True)
                end = time.time()
                logger.info('search_vectors_in_files takes: {}'.format(end - start))

                all_topk_results.append(ret)

        with self.tracer.start_span('do_search',
                                    child_of=p_span) as span:
        with self.tracer.start_span('do_search', child_of=p_span) as span:
            with ThreadPoolExecutor(max_workers=workers) as pool:
                for addr, params in routing.items():
                    res = pool.submit(search, addr, params, vectors, topk, nprobe, span=span)
                    res = pool.submit(search,
                                      addr,
                                      params,
                                      vectors,
                                      topk,
                                      nprobe,
                                      span=span)
                    rs.append(res)

                for res in rs:
                    res.result()

        reverse = table_meta.metric_type == Types.MetricType.IP
        with self.tracer.start_span('do_merge',
                                    child_of=p_span):
            return self._do_merge(all_topk_results, topk, reverse=reverse, metadata=metadata)
        with self.tracer.start_span('do_merge', child_of=p_span):
            return self._do_merge(all_topk_results,
                                  topk,
                                  reverse=reverse,
                                  metadata=metadata)

    def _create_table(self, table_schema):
        return self.connection().create_table(table_schema)
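The merge step that follows the fan-out is easiest to see on toy data: partial top-k lists from each read node are concatenated and re-sorted by distance, keeping only the global top-k. A self-contained sketch of that idea, with plain (id, distance) tuples standing in for the protobuf query_result_arrays:

# Toy illustration of the top-k merge; tuples replace protobuf result rows.
topk = 3
host_a = [(11, 0.10), (12, 0.35), (13, 0.90)]   # partial top-k from one read node
host_b = [(21, 0.20), (22, 0.25), (23, 0.80)]   # partial top-k from another node

merged = sorted(host_a + host_b, key=lambda x: x[1])[:topk]
print(merged)  # [(11, 0.1), (21, 0.2), (22, 0.25)]
# With metric_type == IP the sort runs with reverse=True, since larger scores are better.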
@@ -184,13 +159,15 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
        _status, _table_schema = Parser.parse_proto_TableSchema(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code, reason=_status.message)
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        logger.info('CreateTable {}'.format(_table_schema['table_name']))

        _status = self._create_table(_table_schema)

        return status_pb2.Status(error_code=_status.code, reason=_status.message)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _has_table(self, table_name, metadata=None):
        return self.connection(metadata=metadata).has_table(table_name)
@@ -200,20 +177,18 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
        _status, _table_name = Parser.parse_proto_TableName(request)

        if not _status.OK():
            return milvus_pb2.BoolReply(
                status=status_pb2.Status(error_code=_status.code, reason=_status.message),
                bool_reply=False
            )
            return milvus_pb2.BoolReply(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message),
                                        bool_reply=False)

        logger.info('HasTable {}'.format(_table_name))

        _bool = self._has_table(_table_name, metadata={
            'resp_class': milvus_pb2.BoolReply})
        _bool = self._has_table(_table_name,
                                metadata={'resp_class': milvus_pb2.BoolReply})

        return milvus_pb2.BoolReply(
            status=status_pb2.Status(error_code=status_pb2.SUCCESS, reason="OK"),
            bool_reply=_bool
        )
        return milvus_pb2.BoolReply(status=status_pb2.Status(
            error_code=status_pb2.SUCCESS, reason="OK"),
                                    bool_reply=_bool)

    def _delete_table(self, table_name):
        return self.connection().delete_table(table_name)
@@ -223,13 +198,15 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
        _status, _table_name = Parser.parse_proto_TableName(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code, reason=_status.message)
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        logger.info('DropTable {}'.format(_table_name))

        _status = self._delete_table(_table_name)

        return status_pb2.Status(error_code=_status.code, reason=_status.message)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _create_index(self, table_name, index):
        return self.connection().create_index(table_name, index)
@@ -239,7 +216,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
        _status, unpacks = Parser.parse_proto_IndexParam(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code, reason=_status.message)
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        _table_name, _index = unpacks
@@ -248,21 +226,22 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
        # TODO: interface create_table incomplete
        _status = self._create_index(_table_name, _index)

        return status_pb2.Status(error_code=_status.code, reason=_status.message)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _add_vectors(self, param, metadata=None):
        return self.connection(metadata=metadata).add_vectors(None, None, insert_param=param)
        return self.connection(metadata=metadata).add_vectors(
            None, None, insert_param=param)

    @mark_grpc_method
    def Insert(self, request, context):
        logger.info('Insert')
        # TODO: The SDK interface add_vectors() could be updated to accept a key 'row_id_array'
        _status, _ids = self._add_vectors(metadata={
            'resp_class': milvus_pb2.VectorIds}, param=request)
        return milvus_pb2.VectorIds(
            status=status_pb2.Status(error_code=_status.code, reason=_status.message),
            vector_id_array=_ids
        )
        _status, _ids = self._add_vectors(
            metadata={'resp_class': milvus_pb2.VectorIds}, param=request)
        return milvus_pb2.VectorIds(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
                                    vector_id_array=_ids)

    @mark_grpc_method
    def Search(self, request, context):
@@ -272,22 +251,23 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
        topk = request.topk
        nprobe = request.nprobe

        logger.info('Search {}: topk={} nprobe={}'.format(table_name, topk, nprobe))
        logger.info('Search {}: topk={} nprobe={}'.format(
            table_name, topk, nprobe))

        metadata = {
            'resp_class': milvus_pb2.TopKQueryResultList
        }
        metadata = {'resp_class': milvus_pb2.TopKQueryResultList}

        if nprobe > self.MAX_NPROBE or nprobe <= 0:
            raise exceptions.InvalidArgumentError(message='Invalid nprobe: {}'.format(nprobe),
                                                  metadata=metadata)
            raise exceptions.InvalidArgumentError(
                message='Invalid nprobe: {}'.format(nprobe), metadata=metadata)

        table_meta = self.table_meta.get(table_name, None)

        if not table_meta:
            status, info = self.connection(metadata=metadata).describe_table(table_name)
            status, info = self.connection(
                metadata=metadata).describe_table(table_name)
            if not status.OK():
                raise exceptions.TableNotFoundError(table_name, metadata=metadata)
                raise exceptions.TableNotFoundError(table_name,
                                                    metadata=metadata)

            self.table_meta[table_name] = info
            table_meta = info
@@ -304,16 +284,22 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
            query_range_array.append(
                Range(query_range.start_value, query_range.end_value))

        status, results = self._do_query(context, table_name, table_meta, query_record_array, topk,
                                         nprobe, query_range_array, metadata=metadata)
        status, results = self._do_query(context,
                                         table_name,
                                         table_meta,
                                         query_record_array,
                                         topk,
                                         nprobe,
                                         query_range_array,
                                         metadata=metadata)

        now = time.time()
        logger.info('SearchVector takes: {}'.format(now - start))

        topk_result_list = milvus_pb2.TopKQueryResultList(
            status=status_pb2.Status(error_code=status.error_code, reason=status.reason),
            topk_query_result=results
        )
            status=status_pb2.Status(error_code=status.error_code,
                                     reason=status.reason),
            topk_query_result=results)
        return topk_result_list

    @mark_grpc_method
@@ -328,16 +314,14 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
        _status, _table_name = Parser.parse_proto_TableName(request)

        if not _status.OK():
            return milvus_pb2.TableSchema(
                status=status_pb2.Status(error_code=_status.code, reason=_status.message),
            )
            return milvus_pb2.TableSchema(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message), )

        metadata = {
            'resp_class': milvus_pb2.TableSchema
        }
        metadata = {'resp_class': milvus_pb2.TableSchema}

        logger.info('DescribeTable {}'.format(_table_name))
        _status, _table = self._describe_table(metadata=metadata, table_name=_table_name)
        _status, _table = self._describe_table(metadata=metadata,
                                               table_name=_table_name)

        if _status.OK():
            return milvus_pb2.TableSchema(
@@ -345,37 +329,38 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
                index_file_size=_table.index_file_size,
                dimension=_table.dimension,
                metric_type=_table.metric_type,
                status=status_pb2.Status(error_code=_status.code, reason=_status.message),
                status=status_pb2.Status(error_code=_status.code,
                                         reason=_status.message),
            )

        return milvus_pb2.TableSchema(
            table_name=_table_name,
            status=status_pb2.Status(error_code=_status.code, reason=_status.message),
            status=status_pb2.Status(error_code=_status.code,
                                     reason=_status.message),
        )

    def _count_table(self, table_name, metadata=None):
        return self.connection(metadata=metadata).get_table_row_count(table_name)
        return self.connection(
            metadata=metadata).get_table_row_count(table_name)

    @mark_grpc_method
    def CountTable(self, request, context):
        _status, _table_name = Parser.parse_proto_TableName(request)

        if not _status.OK():
            status = status_pb2.Status(error_code=_status.code, reason=_status.message)
            status = status_pb2.Status(error_code=_status.code,
                                       reason=_status.message)

            return milvus_pb2.TableRowCount(
                status=status
            )
            return milvus_pb2.TableRowCount(status=status)

        logger.info('CountTable {}'.format(_table_name))

        metadata = {
            'resp_class': milvus_pb2.TableRowCount
        }
        metadata = {'resp_class': milvus_pb2.TableRowCount}
        _status, _count = self._count_table(_table_name, metadata=metadata)

        return milvus_pb2.TableRowCount(
            status=status_pb2.Status(error_code=_status.code, reason=_status.message),
            status=status_pb2.Status(error_code=_status.code,
                                     reason=_status.message),
            table_row_count=_count if isinstance(_count, int) else -1)

    def _get_server_version(self, metadata=None):
@@ -387,23 +372,20 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
        logger.info('Cmd: {}'.format(_cmd))

        if not _status.OK():
            return milvus_pb2.StringReply(
                status=status_pb2.Status(error_code=_status.code, reason=_status.message)
            )
            return milvus_pb2.StringReply(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message))

        metadata = {
            'resp_class': milvus_pb2.StringReply
        }
        metadata = {'resp_class': milvus_pb2.StringReply}

        if _cmd == 'version':
            _status, _reply = self._get_server_version(metadata=metadata)
        else:
            _status, _reply = self.connection(metadata=metadata).server_status()
            _status, _reply = self.connection(
                metadata=metadata).server_status()

        return milvus_pb2.StringReply(
            status=status_pb2.Status(error_code=_status.code, reason=_status.message),
            string_reply=_reply
        )
        return milvus_pb2.StringReply(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
                                      string_reply=_reply)

    def _show_tables(self, metadata=None):
        return self.connection(metadata=metadata).show_tables()
@@ -411,18 +393,17 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
    @mark_grpc_method
    def ShowTables(self, request, context):
        logger.info('ShowTables')
        metadata = {
            'resp_class': milvus_pb2.TableName
        }
        metadata = {'resp_class': milvus_pb2.TableName}
        _status, _results = self._show_tables(metadata=metadata)

        return milvus_pb2.TableNameList(
            status=status_pb2.Status(error_code=_status.code, reason=_status.message),
            table_names=_results
        )
        return milvus_pb2.TableNameList(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
                                        table_names=_results)

    def _delete_by_range(self, table_name, start_date, end_date):
        return self.connection().delete_vectors_by_range(table_name, start_date, end_date)
        return self.connection().delete_vectors_by_range(table_name,
                                                         start_date,
                                                         end_date)

    @mark_grpc_method
    def DeleteByRange(self, request, context):
@@ -430,13 +411,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
            Parser.parse_proto_DeleteByRangeParam(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code, reason=_status.message)
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        _table_name, _start_date, _end_date = unpacks

        logger.info('DeleteByRange {}: {} {}'.format(_table_name, _start_date, _end_date))
        logger.info('DeleteByRange {}: {} {}'.format(_table_name, _start_date,
                                                     _end_date))
        _status = self._delete_by_range(_table_name, _start_date, _end_date)
        return status_pb2.Status(error_code=_status.code, reason=_status.message)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _preload_table(self, table_name):
        return self.connection().preload_table(table_name)
@@ -446,11 +430,13 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
        _status, _table_name = Parser.parse_proto_TableName(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code, reason=_status.message)
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        logger.info('PreloadTable {}'.format(_table_name))
        _status = self._preload_table(_table_name)
        return status_pb2.Status(error_code=_status.code, reason=_status.message)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)

    def _describe_index(self, table_name, metadata=None):
        return self.connection(metadata=metadata).describe_index(table_name)
@@ -460,21 +446,22 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
        _status, _table_name = Parser.parse_proto_TableName(request)

        if not _status.OK():
            return milvus_pb2.IndexParam(
                status=status_pb2.Status(error_code=_status.code, reason=_status.message)
            )
            return milvus_pb2.IndexParam(status=status_pb2.Status(
                error_code=_status.code, reason=_status.message))

        metadata = {
            'resp_class': milvus_pb2.IndexParam
        }
        metadata = {'resp_class': milvus_pb2.IndexParam}

        logger.info('DescribeIndex {}'.format(_table_name))
        _status, _index_param = self._describe_index(table_name=_table_name, metadata=metadata)
        _status, _index_param = self._describe_index(table_name=_table_name,
                                                     metadata=metadata)

        _index = milvus_pb2.Index(index_type=_index_param._index_type, nlist=_index_param._nlist)
        _index = milvus_pb2.Index(index_type=_index_param._index_type,
                                  nlist=_index_param._nlist)

        return milvus_pb2.IndexParam(status=status_pb2.Status(error_code=_status.code, reason=_status.message),
                                     table_name=_table_name, index=_index)
        return milvus_pb2.IndexParam(status=status_pb2.Status(
            error_code=_status.code, reason=_status.message),
                                     table_name=_table_name,
                                     index=_index)

    def _drop_index(self, table_name):
        return self.connection().drop_index(table_name)
@@ -484,8 +471,10 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
        _status, _table_name = Parser.parse_proto_TableName(request)

        if not _status.OK():
            return status_pb2.Status(error_code=_status.code, reason=_status.message)
            return status_pb2.Status(error_code=_status.code,
                                     reason=_status.message)

        logger.info('DropIndex {}'.format(_table_name))
        _status = self._drop_index(_table_name)
        return status_pb2.Status(error_code=_status.code, reason=_status.message)
        return status_pb2.Status(error_code=_status.code,
                                 reason=_status.message)
@@ -73,12 +73,14 @@ class DefaultConfig:
    SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_URI')
    SQL_ECHO = env.bool('SQL_ECHO', False)
    TRACING_TYPE = env.str('TRACING_TYPE', '')
    ROUTER_CLASS_NAME = env.str('ROUTER_CLASS_NAME', 'FileBasedHashRingRouter')


class TestingConfig(DefaultConfig):
    SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_TEST_URI')
    SQL_ECHO = env.bool('SQL_TEST_ECHO', False)
    TRACING_TYPE = env.str('TRACING_TEST_TYPE', '')
    ROUTER_CLASS_NAME = env.str('ROUTER_CLASS_TEST_NAME', 'FileBasedHashRingRouter')


if __name__ == '__main__':
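With this, the router implementation is selectable purely through configuration. A minimal sketch of overriding it from the environment before the app is created; the environment variable names come from this commit, while the custom router name and the import path of create_app() are assumptions for illustration:

# Hypothetical override: point mishards at a custom registered router class.
# Assumes settings are read from the environment when the mishards package is
# first imported; 'MyCustomRouter' is illustrative only.
import os
os.environ['ROUTER_CLASS_NAME'] = 'MyCustomRouter'

from mishards import create_app  # assumed import path for create_app()

app = create_app()  # RouterFactory.new_router() now looks up 'MyCustomRouter'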
mishards/utilities.py (new file, 20 lines)
@@ -0,0 +1,20 @@
import datetime
from mishards import exceptions


def format_date(start, end):
    return ((start.year - 1900) * 10000 + (start.month - 1) * 100 + start.day,
            (end.year - 1900) * 10000 + (end.month - 1) * 100 + end.day)


def range_to_date(range_obj, metadata=None):
    try:
        start = datetime.datetime.strptime(range_obj.start_date, '%Y-%m-%d')
        end = datetime.datetime.strptime(range_obj.end_date, '%Y-%m-%d')
        assert start < end
    except (ValueError, AssertionError):
        raise exceptions.InvalidRangeError('Invalid time range: {} {}'.format(
            range_obj.start_date, range_obj.end_date),
                                           metadata=metadata)

    return format_date(start, end)
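Note that these are module-level helpers (the stray `self` parameters carried over from the old ServiceHandler methods are dropped above, matching how `utilities.range_to_date(...)` is called in `_do_query`). The encoding packs year, month and day into one integer: year offset from 1900, month zero-based. A quick worked example under that assumption:

from mishards.utilities import format_date
import datetime

# 2019-06-25 -> (2019 - 1900) * 10000 + (6 - 1) * 100 + 25 = 1190525
start = datetime.datetime(2019, 6, 25)
end = datetime.datetime(2019, 6, 26)
print(format_date(start, end))  # (1190525, 1190526)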