2024-05-27 18:15:43 +08:00

257 lines
5.5 KiB
Go

package streamrpc
import (
"context"
"io"
"sync"
"github.com/golang/protobuf/proto"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
)
// QueryStreamServer is the sending side of a streaming query: results are
// pushed one message at a time with Send, and Context exposes the lifetime
// of the stream.
type QueryStreamServer interface {
	Context() context.Context
	Send(*internalpb.RetrieveResults) error
}

// QueryStreamClient is the receiving side of a streaming query: results are
// pulled one message at a time with Recv, Context exposes the lifetime of the
// stream, and CloseSend closes the client's sending direction.
type QueryStreamClient interface {
	Context() context.Context
	Recv() (*internalpb.RetrieveResults, error)
	CloseSend() error
}
// ConcurrentQueryStreamServer wraps a QueryStreamServer so that Send can be
// called from several goroutines at once: every Send is serialized through an
// internal mutex before reaching the wrapped server.
type ConcurrentQueryStreamServer struct {
	server QueryStreamServer
	mu     sync.Mutex
}

// NewConcurrentQueryStreamServer returns a ConcurrentQueryStreamServer that
// forwards to srv. The zero-value mutex is ready to use.
func NewConcurrentQueryStreamServer(srv QueryStreamServer) *ConcurrentQueryStreamServer {
	return &ConcurrentQueryStreamServer{server: srv}
}

// Context returns the context of the wrapped server; no locking is needed.
func (s *ConcurrentQueryStreamServer) Context() context.Context {
	return s.server.Context()
}

// Send forwards result to the wrapped server while holding the lock, so at
// most one Send is in flight at a time. The lock is released even if the
// wrapped Send panics.
func (s *ConcurrentQueryStreamServer) Send(result *internalpb.RetrieveResults) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	return s.server.Send(result)
}
// RetrieveResultCache buffers retrieve results, merging each new one into a
// single pending result until it is flushed. Sizes are measured with
// proto.Size, i.e. the encoded size of the pending result.
type RetrieveResultCache struct {
	result *internalpb.RetrieveResults // pending (possibly merged) result; nil when nothing is buffered
	size   int                         // proto.Size of result; reset to 0 by Flush
	cap    int                         // size budget consulted by Alloc and IsFull
}

// Put adds result to the cache. The first result after a flush is stored as
// is (the cache keeps the caller's pointer); later ones are merged into it.
func (c *RetrieveResultCache) Put(result *internalpb.RetrieveResults) {
	if c.result != nil {
		c.merge(result)
		return
	}
	c.result, c.size = result, proto.Size(result)
}

// Flush hands back the pending result (nil if nothing was buffered) and
// leaves the cache empty.
func (c *RetrieveResultCache) Flush() *internalpb.RetrieveResults {
	flushed := c.result
	c.result, c.size = nil, 0
	return flushed
}

// Alloc reports whether result could be added without the combined size
// exceeding the capacity. It estimates by summing the two encoded sizes.
func (c *RetrieveResultCache) Alloc(result *internalpb.RetrieveResults) bool {
	return c.size+proto.Size(result) <= c.cap
}

// IsFull reports whether the buffered size has grown beyond the capacity.
func (c *RetrieveResultCache) IsFull() bool {
	return c.cap < c.size
}

// IsEmpty reports whether the buffered size is zero. Note this is also true
// when a result whose encoded size is zero has been buffered.
func (c *RetrieveResultCache) IsEmpty() bool {
	return c.size == 0
}
// merge folds result into the pending cached result:
//   - the primary keys of result are appended to the cached IDs (int or string,
//     following the kind carried by result; a result without IDs adds none),
//   - AllRetrieveCount is summed,
//   - CostAggregation is combined with mergeCostAggregation,
//
// and the cached size is recomputed with proto.Size. Other fields of result
// (e.g. FieldsData) are not copied into the cache.
//
// c.result must be non-nil (Put only calls merge after storing a first
// result). If that first result carried no IDs at all, an ID container of the
// incoming kind is created here instead of dereferencing nil. Mixing int and
// string IDs in one cache is not supported.
func (c *RetrieveResultCache) merge(result *internalpb.RetrieveResults) {
	switch result.GetIds().GetIdField().(type) {
	case *schemapb.IDs_IntId:
		if c.result.GetIds().GetIdField() == nil {
			// Cached result has no IDs yet (e.g. an empty first batch); give it
			// an int-ID container so the append below has somewhere to go.
			c.result.Ids = &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{}}}
		}
		dst := c.result.GetIds().GetIntId()
		dst.Data = append(dst.GetData(), result.GetIds().GetIntId().GetData()...)
	case *schemapb.IDs_StrId:
		if c.result.GetIds().GetIdField() == nil {
			// Same as above for string primary keys.
			c.result.Ids = &schemapb.IDs{IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{}}}
		}
		dst := c.result.GetIds().GetStrId()
		dst.Data = append(dst.GetData(), result.GetIds().GetStrId().GetData()...)
	}
	c.result.AllRetrieveCount += result.AllRetrieveCount
	c.result.CostAggregation = mergeCostAggregation(c.result.GetCostAggregation(), result.GetCostAggregation())
	c.size = proto.Size(c.result)
}
// mergeCostAggregation combines two cost aggregations by summing each field.
// A nil operand counts as absent: the other operand is returned unchanged
// (the same pointer, not a copy), and nil is returned when both are nil.
func mergeCostAggregation(a *internalpb.CostAggregation, b *internalpb.CostAggregation) *internalpb.CostAggregation {
	switch {
	case a == nil:
		return b
	case b == nil:
		return a
	}

	merged := &internalpb.CostAggregation{}
	merged.ResponseTime = a.GetResponseTime() + b.GetResponseTime()
	merged.ServiceTime = a.GetServiceTime() + b.GetServiceTime()
	merged.TotalNQ = a.GetTotalNQ() + b.GetTotalNQ()
	merged.TotalRelatedDataSize = a.GetTotalRelatedDataSize() + b.GetTotalRelatedDataSize()
	return merged
}
// ResultCacheServer wraps a QueryStreamServer and coalesces results before
// forwarding them. Incoming results are merged into a RetrieveResultCache and
// sent to the wrapped server as a batch once the buffered encoded size would
// exceed the configured capacity. Batching is driven by size only (there is
// no time-based flush), so whatever is still buffered at the end of a stream
// is only sent by an explicit Flush.
type ResultCacheServer struct {
	srv   QueryStreamServer
	cache *RetrieveResultCache
	mu    sync.Mutex
}

// NewResultCacheServer wraps srv with a cache whose capacity is cap, measured
// in proto.Size units (encoded bytes) of the merged result.
func NewResultCacheServer(srv QueryStreamServer, cap int) *ResultCacheServer {
	cache := &RetrieveResultCache{cap: cap}
	return &ResultCacheServer{cache: cache, srv: srv}
}
// Send buffers result and forwards coalesced batches to the wrapped server.
//
// If result would not fit next to an already non-empty buffer, the buffered
// batch is sent first. After result has been added, the buffer is sent at
// once if it now exceeds the capacity (for example a single result larger
// than the cap). An error from the wrapped server is returned unchanged; the
// batch that failed to send has already been taken out of the cache, and if
// the first send fails, result itself is not buffered.
func (s *ResultCacheServer) Send(result *internalpb.RetrieveResults) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if !s.cache.IsEmpty() && !s.cache.Alloc(result) {
		if err := s.srv.Send(s.cache.Flush()); err != nil {
			return err
		}
	}

	s.cache.Put(result)
	if !s.cache.IsFull() {
		return nil
	}
	return s.srv.Send(s.cache.Flush())
}
// Flush sends whatever is currently buffered to the wrapped server and
// returns that send's error. When nothing is buffered it returns nil without
// calling the wrapped server.
func (s *ResultCacheServer) Flush() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	pending := s.cache.Flush()
	if pending == nil {
		return nil
	}
	return s.srv.Send(pending)
}

// Context returns the context of the wrapped server.
func (s *ResultCacheServer) Context() context.Context {
	return s.srv.Context()
}
// TODO: local server and client for standalone.
// ONLY FOR TEST.
//
// LocalQueryServer is the in-process sending side paired with a
// LocalQueryClient; it is created by LocalQueryClient.CreateServer and hands
// results to the client through a channel the two share.
type LocalQueryServer struct {
	// Embedded to satisfy grpc.ServerStream. CreateServer leaves it nil, so
	// promoted methods that LocalQueryServer does not define itself would panic.
	grpc.ServerStream
	resultCh   chan *internalpb.RetrieveResults // shared with the client; closed by FinishSend
	ctx        context.Context                  // same context as the client when built by CreateServer
	finishOnce sync.Once                        // ensures FinishSend's body runs at most once
	errCh      chan error                       // carries the terminal error (or io.EOF) from FinishSend to FinishError
	mu         sync.Mutex                       // NOTE(review): not referenced by any method in this section — confirm whether it is needed
}
// Send hands result to the paired LocalQueryClient through the shared result
// channel.
//
// If the server's context is already done, Send returns its error without
// sending. If the channel buffer is full, Send waits for space but gives up
// with the context's error as soon as the context is done, so a client that
// stopped reading cannot leave Send blocked forever.
//
// Send must not be called after FinishSend: FinishSend closes the channel,
// and sending on a closed channel panics.
func (s *LocalQueryServer) Send(result *internalpb.RetrieveResults) error {
	// Check first so a done context always wins, even when the channel still
	// has room (a select with both cases ready would pick one at random).
	if err := s.ctx.Err(); err != nil {
		return err
	}
	select {
	case <-s.ctx.Done():
		return s.ctx.Err()
	case s.resultCh <- result:
		return nil
	}
}
// FinishError blocks until FinishSend has run and returns the stream's
// terminal error: the error passed to FinishSend, or io.EOF if that was nil.
//
// After reading the value, it is put back into errCh. Repeated calls therefore
// keep returning the same error instead of blocking forever on a drained
// channel. One such repeated call is Recv being invoked again after the
// stream has ended.
func (s *LocalQueryServer) FinishError() error {
	err := <-s.errCh
	// Re-publish for later callers. CreateServer gives errCh capacity 1 and
	// FinishSend writes to it at most once, so after the receive above the
	// buffer is empty and this send does not block.
	s.errCh <- err
	return err
}
// Context returns the server's context (CreateServer copies the client's).
func (s *LocalQueryServer) Context() context.Context { return s.ctx }
// FinishSend ends the stream from the server side. Only the first call has
// an effect: it closes the result channel, so a client draining it sees the
// end of the stream, and then records err for FinishError. A nil err is
// recorded as io.EOF. FinishSend always returns nil.
func (s *LocalQueryServer) FinishSend(err error) error {
	s.finishOnce.Do(func() {
		close(s.resultCh)

		terminal := err
		if terminal == nil {
			terminal = io.EOF
		}
		s.errCh <- terminal
	})
	return nil
}
// LocalQueryClient is the in-process receiving side of a local query stream.
// It reads results from a channel that it shares with the LocalQueryServer
// returned by CreateServer.
type LocalQueryClient struct {
	// Embedded to satisfy grpc.ClientStream. NewLocalQueryClient leaves it nil,
	// so promoted methods that LocalQueryClient does not define itself would panic.
	grpc.ClientStream
	server   *LocalQueryServer                // set by CreateServer; asked for the terminal error once resultCh is closed
	resultCh chan *internalpb.RetrieveResults // shared with server
	ctx      context.Context
}
// RecvMsg is required by grpc.ClientStream but is not implemented for the
// local client: calling it always panics. Use Recv to read results.
func (s *LocalQueryClient) RecvMsg(m interface{}) error {
	// TODO: implement me
	panic("implement me")
}
// Recv returns the next result sent by the paired server.
//
// If the client's context is already done, Recv returns its error
// immediately. While waiting for a result it also watches the context, so a
// server that never sends or finishes cannot block Recv forever. Once the
// server has closed the result channel (via FinishSend), Recv returns the
// server's terminal error from FinishError (io.EOF for a normal end).
func (s *LocalQueryClient) Recv() (*internalpb.RetrieveResults, error) {
	// Check first so a done context takes priority over buffered results, the
	// same precedence the non-blocking check gave before.
	if err := s.ctx.Err(); err != nil {
		return nil, err
	}
	select {
	case <-s.ctx.Done():
		return nil, s.ctx.Err()
	case result, ok := <-s.resultCh:
		if !ok {
			return nil, s.server.FinishError()
		}
		return result, nil
	}
}
// Context returns the context the client was created with.
func (s *LocalQueryClient) Context() context.Context { return s.ctx }

// CloseSend does nothing for the local client and always returns nil.
func (s *LocalQueryClient) CloseSend() error { return nil }
// CreateServer builds the LocalQueryServer paired with this client and
// returns it. The new server is stored on the client, replacing any previous
// one. The server shares the client's result channel and context. Its error
// channel has room for one value, so FinishSend's single error send does not
// block.
func (s *LocalQueryClient) CreateServer() *LocalQueryServer {
	srv := &LocalQueryServer{
		ctx:      s.ctx,
		resultCh: s.resultCh,
		errCh:    make(chan error, 1),
	}
	s.server = srv
	return srv
}
func NewLocalQueryClient(ctx context.Context) *LocalQueryClient {
return &LocalQueryClient{
resultCh: make(chan *internalpb.RetrieveResults, 64),
ctx: ctx,
}
}