"""Wrapper around the ``milvus`` SDK client with logging and timing helpers."""
import pdb
import random
import logging
import json
import time
import datetime
import functools
from multiprocessing import Process

import numpy
import sklearn.preprocessing
from milvus import Milvus, IndexType, MetricType

logger = logging.getLogger("milvus_ann_acc.client")

SERVER_HOST_DEFAULT = "127.0.0.1"
SERVER_PORT_DEFAULT = 19530


def time_wrapper(func):
    """
    This decorator logs the execution time for the decorated function.

    The elapsed wall-clock time (rounded to two decimals) is written to the
    module logger at INFO level; the wrapped function's return value is
    passed through unchanged.
    """
    # functools.wraps keeps the wrapped function's __name__/__doc__ intact so
    # decorated methods remain introspectable.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        logger.info("Milvus %s run in %ss", func.__name__, round(end - start, 2))
        return result
    return wrapper


class MilvusClient(object):
    """Convenience wrapper that binds a ``Milvus`` connection to one table.

    The constructor connects immediately. Most methods operate on the table
    name stored on the instance (given to the constructor or adopted by the
    first ``create_table`` call).
    """

    # String names accepted by create_index() -> attribute names on IndexType.
    # Attributes are resolved lazily with getattr so that an SDK version that
    # lacks one of them only fails when that specific index type is requested.
    _INDEX_TYPE_NAMES = {
        "flat": "FLAT",
        "ivf_flat": "IVFLAT",
        "ivf_sq8": "IVF_SQ8",
        "ivf_sq8h": "IVF_SQ8H",
        "mix_nsg": "MIX_NSG",
    }

    def __init__(self, table_name=None, ip=None, port=None):
        """Create the SDK client and connect to the server.

        Args:
            table_name: Name of the table this client operates on (optional;
                may be set later by ``create_table``).
            ip: Server host. When falsy, both the default host and the
                default port are used and ``port`` is ignored.
            port: Server port, used only together with ``ip``.

        Raises:
            Whatever ``Milvus.connect`` raises; connection errors are not
            caught here.
        """
        self._milvus = Milvus()
        self._table_name = table_name
        # Set by create_table(); initialised here so insert()/query() work on
        # a client attached to an already existing table (no normalisation is
        # applied while the metric is unknown).
        self._metric_type = None
        if not ip:
            self._milvus.connect(host=SERVER_HOST_DEFAULT, port=SERVER_PORT_DEFAULT)
        else:
            self._milvus.connect(host=ip, port=port)

    def __str__(self):
        return 'Milvus table %s' % self._table_name

    def check_status(self, status):
        """Log the status message and raise ``Exception`` if ``status.OK()`` is false."""
        if not status.OK():
            logger.error(status.message)
            raise Exception("Status not ok")

    def create_table(self, table_name, dimension, index_file_size, metric_type):
        """Create a table on the server and remember its metric type.

        Args:
            table_name: Name of the table to create. It also becomes this
                client's table name if none was set yet.
            dimension: Forwarded as ``dimension`` in the create parameters.
            index_file_size: Forwarded as ``index_file_size``.
            metric_type: ``"l2"`` or ``"ip"`` (or the corresponding
                ``MetricType`` member).

        Raises:
            ValueError: If ``metric_type`` is not supported.
            Exception: If the server returns a non-OK status.
        """
        if metric_type == "l2":
            metric_type = MetricType.L2
        elif metric_type == "ip":
            metric_type = MetricType.IP
        elif metric_type in (MetricType.L2, MetricType.IP):
            # Already an SDK enum member; use it as-is.
            pass
        else:
            # Fail before touching any state or the server instead of sending
            # a value the server cannot interpret.
            logger.error("Not supported metric_type: %s" % metric_type)
            raise ValueError("Not supported metric_type: %s" % metric_type)
        if not self._table_name:
            self._table_name = table_name
        self._metric_type = metric_type
        create_param = {
            'table_name': table_name,
            'dimension': dimension,
            'index_file_size': index_file_size,
            "metric_type": metric_type,
        }
        status = self._milvus.create_table(create_param)
        self.check_status(status)

    def _prepare_vectors(self, X):
        """Return ``X`` as a nested list of float32 values.

        Rows are L2-normalised first when the table metric is inner product.
        """
        if self._metric_type == MetricType.IP:
            logger.info("Set normalize for metric_type: Inner Product")
            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
        # asarray also accepts plain (nested) lists, not only ndarrays.
        return numpy.asarray(X, dtype=numpy.float32).tolist()

    @time_wrapper
    def insert(self, X, ids):
        """Add the vectors ``X`` with the given ``ids`` to the table.

        Returns:
            The ``(status, result)`` pair returned by ``add_vectors``.

        Raises:
            Exception: If the returned status is not OK.
        """
        vectors = self._prepare_vectors(X)
        status, result = self._milvus.add_vectors(self._table_name, vectors, ids=ids)
        self.check_status(status)
        return status, result

    @time_wrapper
    def create_index(self, index_type, nlist):
        """Build an index on the table.

        Args:
            index_type: One of ``"flat"``, ``"ivf_flat"``, ``"ivf_sq8"``,
                ``"ivf_sq8h"``, ``"mix_nsg"``. Any other value is forwarded
                to the SDK unchanged.
            nlist: Forwarded as ``nlist`` in the index parameters.

        Raises:
            Exception: If the returned status is not OK.
        """
        attr_name = self._INDEX_TYPE_NAMES.get(index_type)
        if attr_name is not None:
            index_type = getattr(IndexType, attr_name)
        index_params = {
            "index_type": index_type,
            "nlist": nlist,
        }
        logger.info("Building index start, table_name: %s, index_params: %s" % (self._table_name, json.dumps(index_params)))
        # The call is given a six hour timeout.
        status = self._milvus.create_index(self._table_name, index=index_params, timeout=6*3600)
        self.check_status(status)

    def describe_index(self):
        """Return the SDK's ``describe_index`` result for the table."""
        return self._milvus.describe_index(self._table_name)

    def drop_index(self):
        """Drop the table's index and return the SDK's result."""
        logger.info("Drop index: %s" % self._table_name)
        return self._milvus.drop_index(self._table_name)

    @time_wrapper
    def query(self, X, top_k, nprobe):
        """Search the table for the vectors in ``X``.

        Returns:
            A list with one entry per element of the search result, each
            entry being the list of ``id`` attributes of its items.

        Raises:
            Exception: If the returned status is not OK.
        """
        vectors = self._prepare_vectors(X)
        status, results = self._milvus.search_vectors(self._table_name, top_k, nprobe, vectors)
        self.check_status(status)
        return [[item.id for item in result] for result in results]

    def count(self):
        """Return the second element of the SDK's ``get_table_row_count`` result."""
        return self._milvus.get_table_row_count(self._table_name)[1]

    def delete(self, timeout=60):
        """Delete the table and wait until ``count()`` becomes falsy.

        ``count()`` is polled once per second, at most ``timeout`` times; if
        it is still truthy afterwards an error is logged (no exception is
        raised).
        """
        logger.info("Start delete table: %s" % self._table_name)
        self._milvus.delete_table(self._table_name)
        waited = 0
        while waited < timeout:
            if not self.count():
                break
            time.sleep(1)
            waited += 1
        if waited >= timeout:
            logger.error("Delete table timeout")

    def describe(self):
        """Return the SDK's ``describe_table`` result for the table."""
        return self._milvus.describe_table(self._table_name)

    def exists_table(self):
        """Return the SDK's ``has_table`` result for the table."""
        return self._milvus.has_table(self._table_name)

    @time_wrapper
    def preload_table(self):
        """Call the SDK's ``preload_table`` for the table and return its result."""
        return self._milvus.preload_table(self._table_name)