mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
update for models
This commit is contained in:
parent
099317edee
commit
5646362157
13
manager.py
13
manager.py
@ -1,5 +1,6 @@
|
||||
import fire
|
||||
from mishards import db
|
||||
from sqlalchemy import and_
|
||||
|
||||
class DBHandler:
|
||||
@classmethod
|
||||
@ -10,5 +11,17 @@ class DBHandler:
|
||||
def drop_all(cls):
|
||||
db.drop_all()
|
||||
|
||||
@classmethod
|
||||
def fun(cls, tid):
|
||||
from mishards.factories import TablesFactory, TableFilesFactory, Tables
|
||||
f = db.Session.query(Tables).filter(and_(
|
||||
Tables.table_id==tid,
|
||||
Tables.state!=Tables.TO_DELETE)
|
||||
).first()
|
||||
print(f)
|
||||
|
||||
# f1 = TableFilesFactory()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
fire.Fire(DBHandler)
|
||||
|
||||
@ -2,7 +2,7 @@ from mishards import settings
|
||||
|
||||
from mishards.db_base import DB
|
||||
db = DB()
|
||||
db.init_db(uri=settings.SQLALCHEMY_DATABASE_URI)
|
||||
db.init_db(uri=settings.SQLALCHEMY_DATABASE_URI, echo=settings.SQL_ECHO)
|
||||
|
||||
from mishards.connections import ConnectionMgr
|
||||
connect_mgr = ConnectionMgr()
|
||||
|
||||
@ -1,15 +1,20 @@
|
||||
import logging
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DB:
|
||||
Model = declarative_base()
|
||||
def __init__(self, uri=None):
|
||||
uri and self.init_db(uri)
|
||||
def __init__(self, uri=None, echo=False):
|
||||
self.echo = echo
|
||||
uri and self.init_db(uri, echo)
|
||||
|
||||
def init_db(self, uri):
|
||||
def init_db(self, uri, echo=False):
|
||||
self.engine = create_engine(uri, pool_size=100, pool_recycle=5, pool_timeout=30,
|
||||
pool_pre_ping=True,
|
||||
echo=echo,
|
||||
max_overflow=0)
|
||||
self.uri = uri
|
||||
session = sessionmaker()
|
||||
|
||||
49
mishards/factories.py
Normal file
49
mishards/factories.py
Normal file
@ -0,0 +1,49 @@
|
||||
import time
|
||||
import datetime
|
||||
import random
|
||||
import factory
|
||||
from factory.alchemy import SQLAlchemyModelFactory
|
||||
from faker import Faker
|
||||
from faker.providers import BaseProvider
|
||||
|
||||
from mishards import db
|
||||
from mishards.models import Tables, TableFiles
|
||||
|
||||
class FakerProvider(BaseProvider):
    """Faker provider emitting date values in the metadata store's integer form."""

    def this_date(self):
        # Encode today's date as (years since 1900) * 10000 +
        # (zero-based month) * 100 + day-of-month.
        # NOTE(review): the -1900 / month-1 offsets look like C's `struct tm`
        # layout — confirm this matches the server's date encoding.
        today = datetime.datetime.today()
        return (today.year - 1900) * 10000 + (today.month - 1) * 100 + today.day


factory.Faker.add_provider(FakerProvider)
|
||||
|
||||
class TablesFactory(SQLAlchemyModelFactory):
    """Factory producing `Tables` rows, committed through the shared db session."""

    class Meta:
        model = Tables
        sqlalchemy_session = db.Session
        sqlalchemy_session_persistence = 'commit'

    id = factory.Faker('random_number', digits=16, fix_len=True)
    table_id = factory.Faker('uuid4')
    state = factory.Faker('random_element', elements=(0, 1, 2, 3))
    dimension = factory.Faker('random_element', elements=(256, 512))
    # BUGFIX: a bare int(time.time()) is evaluated once at import time and
    # shared by every generated row; LazyFunction re-evaluates per instance.
    created_on = factory.LazyFunction(lambda: int(time.time()))
    index_file_size = 0
    engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3))
    metric_type = factory.Faker('random_element', elements=(0, 1))
    nlist = 16384
|
||||
|
||||
class TableFilesFactory(SQLAlchemyModelFactory):
    """Factory producing `TableFiles` rows with a parent `Tables` row attached."""

    class Meta:
        model = TableFiles
        sqlalchemy_session = db.Session
        sqlalchemy_session_persistence = 'commit'

    id = factory.Faker('random_number', digits=16, fix_len=True)
    table = factory.SubFactory(TablesFactory)
    engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3))
    file_id = factory.Faker('uuid4')
    file_type = factory.Faker('random_element', elements=(0, 1, 2, 3, 4))
    file_size = factory.Faker('random_number')
    # BUGFIX: int(time.time()) at class scope is frozen at import time;
    # LazyFunction gives each generated row a fresh timestamp.
    updated_time = factory.LazyFunction(lambda: int(time.time()))
    created_on = factory.LazyFunction(lambda: int(time.time()))
    date = factory.Faker('this_date')
|
||||
150
mishards/hash_ring.py
Normal file
150
mishards/hash_ring.py
Normal file
@ -0,0 +1,150 @@
|
||||
import math
import sys
from bisect import bisect

# hashlib is always present on the Python 3 this module actually targets
# (the code below already relies on str.encode() and the print() function),
# so the historical `import md5` fallback for Python < 2.5 is dead code.
import hashlib

# Digest constructor used for placing nodes and keys on the ring.
md5_constructor = hashlib.md5
class HashRing(object):
    """Consistent-hash ring mapping string keys onto a set of nodes.

    Each node is placed on the ring at multiple virtual points (weighted),
    so keys distribute roughly evenly and adding/removing a node only
    remaps a small fraction of keys.
    """

    def __init__(self, nodes=None, weights=None):
        """`nodes` is a list of objects that have a proper __str__ representation.
        `weights` is dictionary that sets weights to the nodes. The default
        weight is that all nodes are equal.
        """
        self.ring = dict()          # ring position -> node
        self._sorted_keys = []      # sorted ring positions, for bisect lookup

        self.nodes = nodes

        if not weights:
            weights = {}
        self.weights = weights

        self._generate_circle()

    def _generate_circle(self):
        """Generates the circle: place every node at weighted virtual points."""
        total_weight = 0
        for node in self.nodes:
            total_weight += self.weights.get(node, 1)

        for node in self.nodes:
            weight = 1

            if node in self.weights:
                weight = self.weights.get(node)

            # 40 md5 digests * 3 positions each ~= 120 virtual points per
            # node at equal weight.
            factor = math.floor((40 * len(self.nodes) * weight) / total_weight)

            for j in range(0, int(factor)):
                b_key = self._hash_digest('%s-%s' % (node, j))

                # Derive 3 ring positions from one 16-byte digest
                # (4 bytes per position).
                for i in range(0, 3):
                    key = self._hash_val(b_key, lambda x: x + i * 4)
                    self.ring[key] = node
                    self._sorted_keys.append(key)

        self._sorted_keys.sort()

    def get_node(self, string_key):
        """Given a string key a corresponding node in the hash ring is returned.

        If the hash ring is empty, `None` is returned.
        """
        pos = self.get_node_pos(string_key)
        if pos is None:
            return None
        return self.ring[self._sorted_keys[pos]]

    def get_node_pos(self, string_key):
        """Return the position (index into the sorted ring keys) of the node
        responsible for `string_key`.

        If the hash ring is empty, `None` is returned.
        """
        if not self.ring:
            return None

        key = self.gen_key(string_key)

        nodes = self._sorted_keys
        pos = bisect(nodes, key)

        # Past the last point: wrap around to the start of the ring.
        if pos == len(nodes):
            return 0
        else:
            return pos

    def iterate_nodes(self, string_key, distinct=True):
        """Given a string key it returns the nodes as a generator that can hold the key.

        The generator iterates one time through the ring
        starting at the correct position.

        if `distinct` is set, then the nodes returned will be unique,
        i.e. no virtual copies will be returned.
        """
        if not self.ring:
            yield None, None
            # BUGFIX: without this return, the code below slices
            # self._sorted_keys[None:] and raises TypeError.
            return

        # NOTE(review): `distinct` is accepted but the filter is always
        # applied regardless of its value — preserved as-is; confirm intent.
        returned_values = set()

        def distinct_filter(value):
            if str(value) not in returned_values:
                returned_values.add(str(value))
                return value

        pos = self.get_node_pos(string_key)
        for key in self._sorted_keys[pos:]:
            val = distinct_filter(self.ring[key])
            if val:
                yield val

        # Wrap around: visit the ring positions before `pos`.
        for i, key in enumerate(self._sorted_keys):
            if i < pos:
                val = distinct_filter(self.ring[key])
                if val:
                    yield val

    def gen_key(self, key):
        """Given a string key it returns a long value,
        this long value represents a place on the hash ring.

        md5 is currently used because it mixes well.
        """
        b_key = self._hash_digest(key)
        return self._hash_val(b_key, lambda x: x)

    def _hash_val(self, b_key, entry_fn):
        # Fold 4 digest bytes (selected by entry_fn) into a 32-bit integer.
        return ((b_key[entry_fn(3)] << 24)
                | (b_key[entry_fn(2)] << 16)
                | (b_key[entry_fn(1)] << 8)
                | b_key[entry_fn(0)])

    def _hash_digest(self, key):
        m = md5_constructor()
        key = key.encode()
        m.update(key)
        return m.digest()
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: hash 100 keys onto 4 servers and show how they distribute.
    from collections import defaultdict

    servers = ['192.168.0.246:11212',
               '192.168.0.247:11212',
               '192.168.0.248:11212',
               '192.168.0.249:11212']

    ring = HashRing(servers)

    mapped = defaultdict(list)
    for k in ('{}'.format(i) for i in range(100)):
        mapped[ring.get_node(k)].append(k)

    for k, v in mapped.items():
        print(k, v)
|
||||
@ -32,8 +32,8 @@ class TableFiles(db.Model):
|
||||
date = Column(Integer)
|
||||
|
||||
table = relationship(
|
||||
'Table',
|
||||
primaryjoin='and_(foreign(TableFile.table_id) == Table.table_id)',
|
||||
'Tables',
|
||||
primaryjoin='and_(foreign(TableFiles.table_id) == Tables.table_id)',
|
||||
backref=backref('files', uselist=True, lazy='dynamic')
|
||||
)
|
||||
|
||||
@ -57,15 +57,15 @@ class Tables(db.Model):
|
||||
|
||||
def files_to_search(self, date_range=None):
|
||||
cond = or_(
|
||||
TableFile.file_type==TableFile.FILE_TYPE_RAW,
|
||||
TableFile.file_type==TableFile.FILE_TYPE_TO_INDEX,
|
||||
TableFile.file_type==TableFile.FILE_TYPE_INDEX,
|
||||
TableFiles.file_type==TableFiles.FILE_TYPE_RAW,
|
||||
TableFiles.file_type==TableFiles.FILE_TYPE_TO_INDEX,
|
||||
TableFiles.file_type==TableFiles.FILE_TYPE_INDEX,
|
||||
)
|
||||
if date_range:
|
||||
cond = and_(
|
||||
cond,
|
||||
or_(
|
||||
and_(TableFile.date>=d[0], TableFile.date<d[1]) for d in date_range
|
||||
and_(TableFiles.date>=d[0], TableFiles.date<d[1]) for d in date_range
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -250,8 +250,8 @@ class ServiceFounder(object):
|
||||
self.listener.daemon = True
|
||||
self.listener.start()
|
||||
self.event_handler.start()
|
||||
while self.listener.at_start_up:
|
||||
time.sleep(1)
|
||||
# while self.listener.at_start_up:
|
||||
# time.sleep(1)
|
||||
|
||||
self.pod_heartbeater.start()
|
||||
|
||||
|
||||
@ -4,14 +4,17 @@ import datetime
|
||||
from contextlib import contextmanager
|
||||
from collections import defaultdict
|
||||
|
||||
from sqlalchemy import and_
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
|
||||
from milvus.grpc_gen.milvus_pb2 import TopKQueryResult
|
||||
from milvus.client import types
|
||||
|
||||
from mishards import (settings, exceptions)
|
||||
from mishards import (db, settings, exceptions)
|
||||
from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser
|
||||
from mishards.models import Tables, TableFiles
|
||||
from mishards.hash_ring import HashRing
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -53,12 +56,34 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
return self._format_date(start, end)
|
||||
|
||||
def _get_routing_file_ids(self, table_id, range_array):
|
||||
return {
|
||||
'milvus-ro-servers-0': {
|
||||
'table_id': table_id,
|
||||
'file_ids': [123]
|
||||
table = db.Session.query(Tables).filter(and_(
|
||||
Tables.table_id==table_id,
|
||||
Tables.state!=Tables.TO_DELETE
|
||||
)).first()
|
||||
logger.error(table)
|
||||
|
||||
if not table:
|
||||
raise exceptions.TableNotFoundError(table_id)
|
||||
files = table.files_to_search(range_array)
|
||||
|
||||
servers = self.conn_mgr.conn_names
|
||||
logger.info('Available servers: {}'.format(servers))
|
||||
|
||||
ring = HashRing(servers)
|
||||
|
||||
routing = {}
|
||||
|
||||
for f in files:
|
||||
target_host = ring.get_node(str(f.id))
|
||||
sub = routing.get(target_host, None)
|
||||
if not sub:
|
||||
routing[target_host] = {
|
||||
'table_id': table_id,
|
||||
'file_ids': []
|
||||
}
|
||||
}
|
||||
routing[target_host]['file_ids'].append(str(f.id))
|
||||
|
||||
return routing
|
||||
|
||||
def _do_merge(self, files_n_topk_results, topk, reverse=False):
|
||||
if not files_n_topk_results:
|
||||
@ -88,7 +113,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
def _do_query(self, table_id, table_meta, vectors, topk, nprobe, range_array=None, **kwargs):
|
||||
range_array = [self._range_to_date(r) for r in range_array] if range_array else None
|
||||
routing = self._get_routing_file_ids(table_id, range_array)
|
||||
logger.debug(routing)
|
||||
logger.info('Routing: {}'.format(routing))
|
||||
|
||||
rs = []
|
||||
all_topk_results = []
|
||||
|
||||
@ -19,6 +19,7 @@ from mishards.utils.logger_helper import config
|
||||
config(LOG_LEVEL, LOG_PATH, LOG_NAME, TIMEZONE)
|
||||
|
||||
SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_URI')
|
||||
SQL_ECHO = env.bool('SQL_ECHO', False)
|
||||
|
||||
TIMEOUT = env.int('TIMEOUT', 60)
|
||||
MAX_RETRY = env.int('MAX_RETRY', 3)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user