Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
This commit is contained in:
zhuwenxing 2024-08-09 19:17:55 +08:00
parent 5f67e46fe1
commit f09c310cf4

View File

@ -0,0 +1,609 @@
import time
import random
import threading
from concurrent.futures import ThreadPoolExecutor
import json
from datetime import datetime
import requests
from pymilvus import connections, Collection, DataType, FieldSchema, CollectionSchema, utility
from loguru import logger
import matplotlib.pyplot as plt
import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import pandas as pd
import sys
from collections import deque
from statistics import mean
# Replace loguru's default sink with a stdout sink at DEBUG level so all
# of this script's diagnostic output is visible on the console.
logger.remove()
logger.add(sink=sys.stdout, level="DEBUG")
def create_interactive_plot(data, count_data, upstream_data):
    """Build a 5-row interactive Plotly figure of CDC performance metrics.

    Rows: collection counts, insert/sync throughput, replication latency,
    upstream QPS, upstream latency.

    Note: a 'datetime' column is added to each input DataFrame in place
    (converted from the epoch-seconds 'timestamp' column).
    """
    for frame in (data, count_data, upstream_data):
        frame['datetime'] = pd.to_datetime(frame['timestamp'], unit='s')

    fig = make_subplots(rows=5, cols=1,
                        subplot_titles=("Source and Target Collection Count over Time",
                                        "Insert and Sync Throughput over Time",
                                        "Latency over Time",
                                        "Upstream Operations QPS",
                                        "Upstream Operations Latency"),
                        vertical_spacing=0.08)

    # (subplot row, source frame, column name, trace label) for every line.
    trace_specs = [
        (1, count_data, 'source_count', 'Source Count'),
        (1, count_data, 'target_count', 'Target Count'),
        (2, data, 'insert_throughput', 'Insert Throughput'),
        (2, data, 'sync_throughput', 'Sync Throughput'),
        (3, data, 'real_time_latency', 'Real-time Latency'),
        (4, upstream_data, 'delete_qps', 'Delete QPS'),
        (4, upstream_data, 'query_qps', 'Query QPS'),
        (4, upstream_data, 'search_qps', 'Search QPS'),
        (5, upstream_data, 'delete_latency', 'Delete Latency'),
        (5, upstream_data, 'query_latency', 'Query Latency'),
        (5, upstream_data, 'search_latency', 'Search Latency'),
    ]
    for row, frame, column, label in trace_specs:
        fig.add_trace(go.Scatter(x=frame['datetime'], y=frame[column],
                                 mode='lines', name=label),
                      row=row, col=1)

    fig.update_layout(height=2000, width=1000, title_text="Milvus CDC Performance Metrics")

    # Render every x-axis as wall-clock time rather than raw epoch values.
    for row in range(1, 6):
        fig.update_xaxes(title_text="Time", row=row, col=1,
                         tickformat="%Y-%m-%d %H:%M:%S")

    y_axis_titles = ("Entity Count", "Throughput (entities/second)",
                     "Latency (seconds)", "QPS", "Latency (seconds)")
    for row, title in enumerate(y_axis_titles, start=1):
        fig.update_yaxes(title_text=title, row=row, col=1)
    return fig
class MilvusCDCPerformance:
    """End-to-end performance harness for Milvus CDC replication.

    Drives continuous inserts into a *source* Milvus cluster while a CDC
    service replicates them to a *target* cluster. Background threads
    measure:
      * replication latency — time until the newest inserted row becomes
        queryable on the target,
      * insert / sync throughput — via periodic count(*) on both sides,
      * QPS and latency of an upstream delete/query/search workload,
    and a control thread pauses/resumes the CDC tasks mid-run through the
    CDC server's HTTP API. Results are logged and rendered to
    'milvus_cdc_performance.html'.
    """

    def __init__(self, source_alias, target_alias, cdc_host):
        # Timestamps of pre-inserted rows reserved for the delete workload;
        # perform_delete pops one per deletion.
        self.delete_entities_ts = []
        # Tri-state: None until the first pause/resume call, then True/False.
        self.cdc_paused = None
        self.source_alias = source_alias
        self.target_alias = target_alias
        self.cdc_host = cdc_host
        self.source_collection = None
        self.target_collection = None
        self.insert_count = 0
        self.sync_count = 0
        self.source_count = 0
        self.target_count = 0
        self.insert_lock = threading.Lock()
        self.sync_lock = threading.Lock()
        self.latest_insert_ts = 0
        self.latest_query_ts = 0
        # Cooperative shutdown flag polled by every worker thread.
        self.stop_query = False
        # Observed source->target replication latencies, in seconds.
        self.latencies = []
        # Snapshot of the most recent insert; continuous_query uses it to
        # probe the target collection.
        self.latest_insert_status = {
            "latest_ts": 0,
            "latest_count": 0
        }
        self.start_time = time.time()
        self.last_report_time = time.time()
        self.last_source_count = 0
        self.last_target_count = 0
        # Per-interval samples appended together by report_realtime_metrics;
        # all five lists must stay the same length (they become DataFrame columns).
        self.time_series_data = {
            'timestamp': [],
            'insert_throughput': [],
            'sync_throughput': [],
            'avg_latency': [],
            'real_time_latency': []
        }
        # Periodic count(*) snapshots from continuous_count.
        self.count_series_data = {
            'timestamp': [],
            'source_count': [],
            'target_count': []
        }
        # Each entry is [latency_seconds, completion_wallclock_ts].
        self.delete_count = []
        self.query_count = []
        self.search_count = []
        # Last inserted batch (column format); reused by query/search workloads.
        self.entities = []
        self.delete_entities = []
        # Sliding-window upstream metrics from continuous_monitoring_upstream;
        # lists must stay aligned (DataFrame columns).
        self.upstream_data = {
            'timestamp': [],
            'delete_qps': [],
            'query_qps': [],
            'search_qps': [],
            'delete_latency': [],
            'query_latency': [],
            'search_latency': []
        }

    def report_realtime_metrics(self):
        """Log and record insert/sync throughput and latency since the last call.

        Appends one sample to every list in time_series_data, keeping them
        aligned for DataFrame construction.
        """
        current_time = time.time()
        if self.last_report_time is None:
            # First call: establish a baseline, nothing to report yet.
            self.last_report_time = current_time
            self.last_source_count = self.source_count
            self.last_target_count = self.target_count
            return
        time_diff = current_time - self.last_report_time
        if time_diff <= 0:
            # Clock did not advance; skip the whole sample (all series stay
            # aligned) instead of dividing by zero.
            return
        insert_diff = self.source_count - self.last_source_count
        sync_diff = self.target_count - self.last_target_count
        insert_throughput = insert_diff / time_diff
        sync_throughput = sync_diff / time_diff
        # Rolling average over the most recent 100 latency observations.
        avg_latency = sum(self.latencies[-100:]) / len(self.latencies[-100:]) if self.latencies else 0
        # Most recent single observation (0 until the first sync completes).
        real_time_latency = self.latencies[-1] if self.latencies else 0
        logger.info(f"Real-time metrics:")
        logger.info(f"  Insert Throughput: {insert_throughput:.2f} entities/second")
        logger.info(f"  Sync Throughput: {sync_throughput:.2f} entities/second")
        logger.info(f"  Avg Latency (last 100): {avg_latency:.2f} seconds")
        # Store all five series together so the lists can never go ragged.
        self.time_series_data['timestamp'].append(current_time)
        self.time_series_data['insert_throughput'].append(insert_throughput)
        self.time_series_data['sync_throughput'].append(sync_throughput)
        self.time_series_data['avg_latency'].append(avg_latency)
        self.time_series_data['real_time_latency'].append(real_time_latency)
        self.last_report_time = current_time
        self.last_source_count = self.source_count
        self.last_target_count = self.target_count

    def continuous_monitoring(self, interval=5):
        """Sample throughput/latency every `interval` seconds until shutdown."""
        while not self.stop_query:
            self.report_realtime_metrics()
            time.sleep(interval)

    def continuous_monitoring_upstream(self, window_size=60):
        """Report QPS and mean latency of upstream ops over a sliding window."""
        while not self.stop_query:
            current_time = time.time()
            window_start = current_time - window_size
            # Keep only operations that completed inside the window.
            delete_window = deque([op for op in self.delete_count if op[1] > window_start])
            query_window = deque([op for op in self.query_count if op[1] > window_start])
            search_window = deque([op for op in self.search_count if op[1] > window_start])
            # QPS = completed operations / window length.
            delete_qps = len(delete_window) / window_size
            query_qps = len(query_window) / window_size
            search_qps = len(search_window) / window_size
            # Mean latency over the window (0 when no operations completed).
            delete_latency = mean([op[0] for op in delete_window]) if delete_window else 0
            query_latency = mean([op[0] for op in query_window]) if query_window else 0
            search_latency = mean([op[0] for op in search_window]) if search_window else 0
            # Store one aligned sample across all upstream series.
            self.upstream_data['timestamp'].append(current_time)
            self.upstream_data['delete_qps'].append(delete_qps)
            self.upstream_data['query_qps'].append(query_qps)
            self.upstream_data['search_qps'].append(search_qps)
            self.upstream_data['delete_latency'].append(delete_latency)
            self.upstream_data['query_latency'].append(query_latency)
            self.upstream_data['search_latency'].append(search_latency)
            logger.info(f"Upstream metrics for the last {window_size} seconds:")
            logger.info(f"Delete - QPS: {delete_qps:.2f}, Avg Latency: {delete_latency:.4f}s")
            logger.info(f"Query - QPS: {query_qps:.2f}, Avg Latency: {query_latency:.4f}s")
            logger.info(f"Search - QPS: {search_qps:.2f}, Avg Latency: {search_latency:.4f}s")
            # Drop expired entries so the counters don't grow without bound.
            self.delete_count = list(delete_window)
            self.query_count = list(query_window)
            self.search_count = list(search_window)
            time.sleep(5)

    def plot_time_series_data(self):
        """Render all collected metrics to an interactive HTML report."""
        df = pd.DataFrame(self.time_series_data)
        count_df = pd.DataFrame(self.count_series_data)
        upstream_df = pd.DataFrame(self.upstream_data)
        fig = create_interactive_plot(df, count_df, upstream_df)
        pio.write_html(fig, file='milvus_cdc_performance.html', auto_open=True)
        logger.info("Interactive performance plot saved as 'milvus_cdc_performance.html'")

    def _cdc_request(self, request_type, request_data=None):
        """POST one control request to the CDC server and return parsed JSON.

        A bounded timeout prevents a hung CDC server from stalling the run.
        """
        url = f"http://{self.cdc_host}:8444/cdc"
        body = {"request_type": request_type}
        if request_data is not None:
            body["request_data"] = request_data
        response = requests.post(url, data=json.dumps(body), timeout=30)
        return response.json()

    def list_cdc_tasks(self):
        """Return the list of CDC task descriptors from the CDC server."""
        result = self._cdc_request("list")
        logger.info(f"List CDC tasks response: {result}")
        return result["data"]["tasks"]

    def pause_cdc_tasks(self):
        """Pause every CDC task so replication to the target stops."""
        tasks = self.list_cdc_tasks()
        for task in tasks:
            task_id = task["task_id"]
            result = self._cdc_request("pause", {"task_id": task_id})
            logger.info(f"Pause CDC task {task_id} response: {result}")
        self.cdc_paused = True
        logger.info("All CDC tasks paused")

    def resume_cdc_tasks(self):
        """Resume every CDC task so replication to the target continues."""
        tasks = self.list_cdc_tasks()
        for task in tasks:
            task_id = task["task_id"]
            result = self._cdc_request("resume", {"task_id": task_id})
            logger.info(f"Resume CDC task {task_id} response: {result}")
        self.cdc_paused = False
        logger.info("All CDC tasks resumed")

    def pause_and_resume_cdc_tasks(self, duration):
        """Pause CDC during the middle third of a `duration`-second run."""
        time.sleep(duration / 3)
        self.pause_cdc_tasks()
        time.sleep(duration / 3)
        self.resume_cdc_tasks()

    def setup_collections(self):
        """Create the test collection on the source and seed delete data.

        Resumes CDC first so the new collection and seed rows replicate to
        the target; also pre-inserts 10x10000 tag==1 rows whose timestamps
        feed the delete workload.
        """
        self.resume_cdc_tasks()
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="timestamp", dtype=DataType.INT64),
            FieldSchema(name="tag", dtype=DataType.INT64),
            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=128)
        ]
        schema = CollectionSchema(fields, "Milvus CDC test collection")
        # Unique name per run so repeated runs don't collide.
        c_name = "milvus_cdc_perf_test" + datetime.now().strftime("%Y%m%d%H%M%S")
        self.source_collection = Collection(c_name, schema, using=self.source_alias, num_shards=4)
        # Give CDC a moment to replicate the collection before opening it
        # on the target.
        time.sleep(5)
        self.target_collection = Collection(c_name, using=self.target_alias)
        index_params = {
            "index_type": "IVF_FLAT",
            "metric_type": "L2",
            "params": {"nlist": 1024}
        }
        self.source_collection.create_index("vector", index_params)
        self.source_collection.load()
        time.sleep(1)
        logger.info(f"source collection: {self.source_collection.describe()}")
        logger.info(f"target collection: {self.target_collection.describe()}")
        # Seed rows for the delete workload; their microsecond-based
        # timestamps are recorded so perform_delete can target them exactly.
        for i in range(10):
            entities = self.generate_data(10000, tag=1)
            entities[0] = [int(time.time() * 1000000) - i * 1000 for _ in range(10000)]
            self.source_collection.insert(entities)
            self.delete_entities_ts.extend(entities[0])

    def generate_data(self, num_entities, tag=0):
        """Return column-format data [timestamps_ms, tags, 128-dim vectors]."""
        return [
            [int(time.time() * 1000) for _ in range(num_entities)],  # timestamp (ms)
            [tag for _ in range(num_entities)],  # tag
            [[random.random() for _ in range(128)] for _ in range(num_entities)]  # vector
        ]

    def continuous_insert(self, duration, batch_size):
        """Insert batches into the source for `duration` seconds.

        Updates insert_count and latest_insert_status under insert_lock so
        continuous_query can track the newest row.
        """
        end_time = time.time() + duration
        while time.time() < end_time:
            entities = self.generate_data(batch_size)
            self.entities = entities
            self.source_collection.insert(entities)
            with self.insert_lock:
                self.insert_count += batch_size
                # Record the newest timestamp of this batch for sync tracking.
                self.latest_insert_status = {
                    "latest_ts": entities[0][-1],
                    "latest_count": self.insert_count
                }
            # Sample ~10% of batches for debug logging to limit log volume.
            if random.random() < 0.1:
                logger.debug(
                    f"insert_count: {self.insert_count}, latest_ts: {self.latest_insert_status['latest_ts']}")
            time.sleep(0.01)  # Small delay to prevent overwhelming the system

    def perform_delete(self):
        """Delete seed rows (tag==1) one at a time until shutdown."""
        while not self.stop_query:
            try:
                if len(self.entities) < 1:
                    # No insert has happened yet; wait for the workload to start.
                    time.sleep(5)
                    continue
                ts = self.delete_entities_ts.pop()
                expr = f"timestamp == {ts} and tag == 1"
                t0 = time.time()
                res = self.source_collection.delete(expr)
                tt = time.time() - t0
                logger.debug(f"Delete result:{res} with expr {expr}, cost {tt} seconds")
                self.delete_count.append([tt, time.time()])
            except Exception as e:
                # Best-effort workload: log and keep going (e.g. pop from empty list).
                logger.debug(f"Delete failed: {e}")
            time.sleep(0.01)

    def perform_query(self):
        """Run range queries against the source until shutdown."""
        while not self.stop_query:
            try:
                if len(self.entities) < 1:
                    time.sleep(5)
                    continue
                # Query a window ending at "now" that overlaps recent inserts.
                mid_expr = self.entities[0][500] - 5000
                max_expr = int(time.time() * 1000)
                expr = f"timestamp >= {mid_expr} && timestamp <= {max_expr}"
                t0 = time.time()
                res = self.source_collection.query(expr)
                tt = time.time() - t0
                logger.debug(f"Query result: {len(res)} with expr {expr},cost time: {tt} seconds")
                self.query_count.append([tt, time.time()])
            except Exception as e:
                logger.debug(f"Query failed: {e}")

    def perform_search(self):
        """Run single-vector searches against the source until shutdown."""
        while not self.stop_query:
            try:
                if len(self.entities) < 1:
                    time.sleep(5)
                    continue
                # Reuse one vector from the most recent insert batch.
                vectors = self.entities[2][100:101]
                t0 = time.time()
                res = self.source_collection.search(
                    data=vectors,
                    anns_field="vector",
                    param={"metric_type": "L2", "params": {"nprobe": 16}},
                    limit=1
                )
                tt = time.time() - t0
                logger.debug(f"Search result: {res}, cost time: {tt} seconds")
                self.search_count.append([tt, time.time()])
            except Exception as e:
                logger.debug(f"Search failed: {e}")

    def continuous_query(self):
        """Measure replication latency by polling the target for the newest row.

        For each new latest_insert_ts, polls the target (up to 10 minutes)
        until the row appears, then records end-to-end latency computed from
        the row's own millisecond timestamp.
        """
        while not self.stop_query:
            with self.insert_lock:
                latest_insert_ts = self.latest_insert_status["latest_ts"]
            if latest_insert_ts > self.latest_query_ts:
                try:
                    results = []
                    t0 = time.time()
                    # Poll until the row is visible on the target or 10 min elapse.
                    while time.time() - t0 < 10 * 60:
                        try:
                            results = self.target_collection.query(
                                expr=f"timestamp == {latest_insert_ts}",
                                output_fields=["timestamp"],
                                limit=1
                            )
                        except Exception as e:
                            logger.debug(f"Query failed: {e}")
                        if len(results) > 0 and results[0]["timestamp"] == latest_insert_ts:
                            logger.debug(
                                f"query latest_insert_ts: {latest_insert_ts}, results: {results} query latency: {time.time() - t0} seconds")
                            break
                    if len(results) > 0 and results[0]["timestamp"] == latest_insert_ts:
                        end_time = time.time()
                        # latest_insert_ts is milliseconds since epoch.
                        latency = end_time - (latest_insert_ts / 1000)
                        with self.sync_lock:
                            self.latest_query_ts = latest_insert_ts
                            self.latencies.append(latency)
                        logger.debug(
                            f"query latest_insert_ts: {latest_insert_ts}, results: {results} query latency: {latency} seconds")
                except Exception as e:
                    logger.debug(f"Query failed: {e}")
            time.sleep(0.01)  # Query interval

    def continuous_count(self):
        """Sample count(*) on both clusters in parallel and log sync progress."""

        def count_target():
            # Count only tag==0 rows: the live insert workload, excluding
            # the tag==1 delete seed data. Falls back to the last good value
            # on failure so throughput deltas don't spike to zero.
            try:
                results = self.target_collection.query(
                    expr="tag == 0",
                    output_fields=["count(*)"],
                    timeout=5
                )
                self.target_count = results[0]['count(*)']
            except Exception as e:
                logger.error(f"Target count failed: {e}")
                self.target_count = self.last_target_count

        def count_source():
            try:
                results = self.source_collection.query(
                    expr="tag == 0",
                    output_fields=["count(*)"],
                    timeout=5
                )
                self.source_count = results[0]['count(*)']
            except Exception as e:
                logger.error(f"Source count failed: {e}")
                self.source_count = self.last_source_count

        # Baseline so sync_count reflects only rows synced during this run.
        previous_count = self.target_collection.query(
            expr="",
            output_fields=["count(*)"],
        )[0]['count(*)']
        while not self.stop_query:
            try:
                # Count both sides concurrently for a consistent snapshot.
                thread1 = threading.Thread(target=count_target)
                thread2 = threading.Thread(target=count_source)
                thread1.start()
                thread2.start()
                thread1.join()
                thread2.join()
                progress = (self.target_count / self.source_count) * 100 if self.source_count > 0 else 0
                self.sync_count = self.target_count - previous_count
                self.count_series_data['timestamp'].append(time.time())
                self.count_series_data['source_count'].append(self.source_count)
                self.count_series_data['target_count'].append(self.target_count)
                logger.debug(f"sync progress {self.target_count}/{self.source_count} {progress:.2f}%")
            except Exception as e:
                logger.error(f"Count failed: {e}")
            time.sleep(0.01)  # Query interval

    def measure_performance(self, duration, batch_size, concurrency):
        """Run one full measurement pass and return summary statistics.

        Returns (total_time, insert_count, sync_count, insert_throughput,
        sync_throughput, avg_latency, min_latency, max_latency).
        """
        # Reset per-run state.
        self.insert_count = 0
        self.sync_count = 0
        self.latest_insert_ts = 0
        self.latest_query_ts = int(time.time() * 1000)
        self.latencies = []
        self.stop_query = False
        start_time = time.time()
        # Background measurement / control threads.
        query_thread = threading.Thread(target=self.continuous_query)
        query_thread.start()
        count_thread = threading.Thread(target=self.continuous_count)
        count_thread.start()
        cdc_thread = threading.Thread(target=self.pause_and_resume_cdc_tasks, args=(duration,))
        cdc_thread.start()
        upstream_op = [self.perform_query, self.perform_search, self.perform_delete]
        upstream_op_task = []
        for op in upstream_op:
            task = threading.Thread(target=op)
            task.start()
            upstream_op_task.append(task)
        monitor_thread = threading.Thread(target=self.continuous_monitoring)
        monitor_thread.start()
        monitoring_upstream_thread = threading.Thread(target=self.continuous_monitoring_upstream)
        monitoring_upstream_thread.start()
        # Foreground insert workload; blocks until `duration` elapses.
        with ThreadPoolExecutor(max_workers=concurrency) as executor:
            futures = [executor.submit(self.continuous_insert, duration, batch_size) for _ in range(concurrency)]
            for future in futures:
                future.result()
        # Signal shutdown and wait for every background thread, including the
        # upstream monitor, so nothing mutates the series while we plot.
        self.stop_query = True
        query_thread.join()
        count_thread.join()
        cdc_thread.join()
        for task in upstream_op_task:
            task.join()
        monitor_thread.join()
        monitoring_upstream_thread.join()
        end_time = time.time()
        total_time = end_time - start_time
        insert_throughput = self.insert_count / total_time
        sync_throughput = self.sync_count / total_time
        # Guard all latency aggregates: if no row ever synced, latencies is
        # empty and min()/max() would raise ValueError.
        avg_latency = sum(self.latencies) / len(self.latencies) if self.latencies else 0
        min_latency = min(self.latencies) if self.latencies else 0
        max_latency = max(self.latencies) if self.latencies else 0
        logger.info(f"Test duration: {total_time:.2f} seconds")
        logger.info(f"Total inserted: {self.insert_count}")
        logger.info(f"Total synced: {self.sync_count}")
        logger.info(f"Insert throughput: {insert_throughput:.2f} entities/second")
        logger.info(f"Sync throughput: {sync_throughput:.2f} entities/second")
        logger.info(f"Average latency: {avg_latency:.2f} seconds")
        logger.info(f"Min latency: {min_latency:.2f} seconds")
        logger.info(f"Max latency: {max_latency:.2f} seconds")
        self.plot_time_series_data()
        return (total_time, self.insert_count, self.sync_count, insert_throughput,
                sync_throughput, avg_latency, min_latency, max_latency)

    def test_scalability(self, max_duration=600, batch_size=1000, max_concurrency=10):
        """Run measure_performance at increasing concurrency and log a summary."""
        results = []
        for concurrency in range(10, max_concurrency + 1, 10):
            # Make sure CDC is running before each pass (a previous pass
            # pauses/resumes it mid-run).
            self.resume_cdc_tasks()
            logger.info(f"\nTesting with concurrency: {concurrency}")
            total_time, insert_count, sync_count, insert_throughput, sync_throughput, avg_latency, min_latency, max_latency = self.measure_performance(
                max_duration, batch_size, concurrency)
            results.append((concurrency, total_time, insert_count, sync_count, insert_throughput, sync_throughput,
                            avg_latency, min_latency, max_latency))
        logger.info("\nScalability Test Results:")
        for concurrency, total_time, insert_count, sync_count, insert_throughput, sync_throughput, avg_latency, min_latency, max_latency in results:
            logger.info(f"Concurrency: {concurrency}")
            logger.info(f"  Insert Throughput: {insert_throughput:.2f} entities/second")
            logger.info(f"  Sync Throughput: {sync_throughput:.2f} entities/second")
            logger.info(f"  Avg Latency: {avg_latency:.2f} seconds")
            logger.info(f"  Min Latency: {min_latency:.2f} seconds")
            logger.info(f"  Max Latency: {max_latency:.2f} seconds")
        return results

    def run_all_tests(self, duration=300, batch_size=1000, max_concurrency=10):
        """Set up collections and run the full scalability test suite."""
        logger.info("Starting Milvus CDC Performance Tests")
        self.setup_collections()
        self.test_scalability(duration, batch_size, max_concurrency)
        logger.info("Milvus CDC Performance Tests Completed")
if __name__ == "__main__":
    import argparse

    def _parse_args():
        """Parse command-line options for the CDC performance run."""
        parser = argparse.ArgumentParser(description='cdc perf test')
        parser.add_argument('--source_uri', type=str, default='http://10.104.18.188:19530', help='source uri')
        parser.add_argument('--source_token', type=str, default='root:Milvus', help='source token')
        parser.add_argument('--target_uri', type=str, default='http://10.104.4.87:19530', help='target uri')
        parser.add_argument('--target_token', type=str, default='root:Milvus', help='target token')
        parser.add_argument('--cdc_host', type=str, default='10.104.32.13', help='cdc host')
        parser.add_argument('--test_duration', type=int, default=600, help='cdc test duration in seconds')
        return parser.parse_args()

    args = _parse_args()
    # Register both clusters under fixed aliases, then run the full suite.
    connections.connect("source", uri=args.source_uri, token=args.source_token)
    connections.connect("target", uri=args.target_uri, token=args.target_token)
    cdc_test = MilvusCDCPerformance("source", "target", args.cdc_host)
    cdc_test.run_all_tests(duration=args.test_duration, batch_size=1000, max_concurrency=10)