diff --git a/core/src/dog_segment/collection_c.h b/core/src/dog_segment/collection_c.h index 02ae2c854e..5d3c8aea2f 100644 --- a/core/src/dog_segment/collection_c.h +++ b/core/src/dog_segment/collection_c.h @@ -4,9 +4,11 @@ extern "C" { typedef void* CCollection; -CCollection NewCollection(const char* collection_name, const char* schema_conf); +CCollection +NewCollection(const char* collection_name, const char* schema_conf); -void DeleteCollection(CCollection collection); +void +DeleteCollection(CCollection collection); #ifdef __cplusplus } diff --git a/core/src/dog_segment/partition_c.cpp b/core/src/dog_segment/partition_c.cpp index 2bd95ed13e..9901082b94 100644 --- a/core/src/dog_segment/partition_c.cpp +++ b/core/src/dog_segment/partition_c.cpp @@ -17,7 +17,8 @@ NewPartition(CCollection collection, const char* partition_name) { return (void*)partition.release(); } -void DeletePartition(CPartition partition) { +void +DeletePartition(CPartition partition) { auto p = (milvus::dog_segment::Partition*)partition; // TODO: delete print diff --git a/core/src/dog_segment/partition_c.h b/core/src/dog_segment/partition_c.h index 96e86cac72..d1bfeead05 100644 --- a/core/src/dog_segment/partition_c.h +++ b/core/src/dog_segment/partition_c.h @@ -6,9 +6,11 @@ extern "C" { typedef void* CPartition; -CPartition NewPartition(CCollection collection, const char* partition_name); +CPartition +NewPartition(CCollection collection, const char* partition_name); -void DeletePartition(CPartition partition); +void +DeletePartition(CPartition partition); #ifdef __cplusplus } diff --git a/core/src/dog_segment/segment_c.cpp b/core/src/dog_segment/segment_c.cpp index c60173a195..6c54e5ed84 100644 --- a/core/src/dog_segment/segment_c.cpp +++ b/core/src/dog_segment/segment_c.cpp @@ -1,7 +1,10 @@ +#include + #include "SegmentBase.h" #include "segment_c.h" #include "Partition.h" + CSegmentBase NewSegment(CPartition partition, unsigned long segment_id) { auto p = (milvus::dog_segment::Partition*)partition; @@ -15,7 +18,9 @@ NewSegment(CPartition partition, unsigned long segment_id) { return (void*)segment.release(); } -void DeleteSegment(CSegmentBase segment) { + +void +DeleteSegment(CSegmentBase segment) { auto s = (milvus::dog_segment::SegmentBase*)segment; // TODO: delete print @@ -23,13 +28,15 @@ void DeleteSegment(CSegmentBase segment) { delete s; } -int Insert(CSegmentBase c_segment, - signed long int size, - const unsigned long* primary_keys, - const unsigned long int* timestamps, - void* raw_data, - int sizeof_per_row, - signed long int count) { + +int +Insert(CSegmentBase c_segment, + signed long int size, + const unsigned long* primary_keys, + const unsigned long* timestamps, + void* raw_data, + int sizeof_per_row, + signed long int count) { auto segment = (milvus::dog_segment::SegmentBase*)c_segment; milvus::dog_segment::DogDataChunk dataChunk{}; @@ -41,3 +48,34 @@ int Insert(CSegmentBase c_segment, return res.code(); } + +int +Delete(CSegmentBase c_segment, + long size, + const unsigned long* primary_keys, + const unsigned long* timestamps) { + auto segment = (milvus::dog_segment::SegmentBase*)c_segment; + + auto res = segment->Delete(size, primary_keys, timestamps); + return res.code(); +} + + +int +Search(CSegmentBase c_segment, + void* fake_query, + unsigned long timestamp, + long int* result_ids, + float* result_distances) { + auto segment = (milvus::dog_segment::SegmentBase*)c_segment; + milvus::dog_segment::QueryResult query_result; + + auto res = segment->Query(nullptr, timestamp, query_result); + + // result_ids and result_distances have been allocated memory in goLang, + // so we don't need to malloc here. + memcpy(result_ids, query_result.result_ids_.data(), query_result.row_num_ * sizeof(long int)); + memcpy(result_distances, query_result.result_distances_.data(), query_result.row_num_ * sizeof(float)); + + return res.code(); +} diff --git a/core/src/dog_segment/segment_c.h b/core/src/dog_segment/segment_c.h index 3ebe9c1bda..f022c3321c 100644 --- a/core/src/dog_segment/segment_c.h +++ b/core/src/dog_segment/segment_c.h @@ -6,17 +6,33 @@ extern "C" { typedef void* CSegmentBase; -CSegmentBase NewSegment(CPartition partition, unsigned long segment_id); +CSegmentBase +NewSegment(CPartition partition, unsigned long segment_id); -void DeleteSegment(CSegmentBase segment); +void +DeleteSegment(CSegmentBase segment); -int Insert(CSegmentBase c_segment, - signed long int size, - const unsigned long* primary_keys, - const unsigned long int* timestamps, - void* raw_data, - int sizeof_per_row, - signed long int count); +int +Insert(CSegmentBase c_segment, + signed long int size, + const unsigned long* primary_keys, + const unsigned long* timestamps, + void* raw_data, + int sizeof_per_row, + signed long int count); + +int +Delete(CSegmentBase c_segment, + long size, + const unsigned long* primary_keys, + const unsigned long* timestamps); + +int +Search(CSegmentBase c_segment, + void* fake_query, + unsigned long timestamp, + long int* result_ids, + float* result_distances); #ifdef __cplusplus } diff --git a/core/unittest/test_c_api.cpp b/core/unittest/test_c_api.cpp index 530a7cbefd..302936369e 100644 --- a/core/unittest/test_c_api.cpp +++ b/core/unittest/test_c_api.cpp @@ -6,6 +6,7 @@ #include "dog_segment/collection_c.h" #include "dog_segment/segment_c.h" + TEST(CApiTest, CollectionTest) { auto collection_name = "collection0"; auto schema_tmp_conf = "null_schema"; @@ -13,6 +14,7 @@ TEST(CApiTest, CollectionTest) { DeleteCollection(collection); } + TEST(CApiTest, PartitonTest) { auto collection_name = "collection0"; auto schema_tmp_conf = "null_schema"; @@ -23,6 +25,7 @@ TEST(CApiTest, PartitonTest) { DeletePartition(partition); } + TEST(CApiTest, SegmentTest) { auto collection_name = "collection0"; auto schema_tmp_conf = "null_schema"; @@ -72,3 +75,89 @@ TEST(CApiTest, InsertTest) { DeletePartition(partition); DeleteSegment(segment); } + + +TEST(CApiTest, DeleteTest) { + auto collection_name = "collection0"; + auto schema_tmp_conf = "null_schema"; + auto collection = NewCollection(collection_name, schema_tmp_conf); + auto partition_name = "partition0"; + auto partition = NewPartition(collection, partition_name); + auto segment = NewSegment(partition, 0); + + std::vector raw_data; + std::vector timestamps; + std::vector uids; + int N = 10000; + std::default_random_engine e(67); + for(int i = 0; i < N; ++i) { + uids.push_back(100000 + i); + timestamps.push_back(0); + // append vec + float vec[16]; + for(auto &x: vec) { + x = e() % 2000 * 0.001 - 1.0; + } + raw_data.insert(raw_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec)); + int age = e() % 100; + raw_data.insert(raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age)); + } + + auto line_sizeof = (sizeof(int) + sizeof(float) * 16); + + auto ins_res = Insert(segment, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); + assert(ins_res == 0); + + unsigned long delete_primary_keys[] = {100000, 100001, 100002}; + unsigned long delete_timestamps[] = {0, 0, 0}; + + auto del_res = Delete(segment, 1, delete_primary_keys, delete_timestamps); + assert(del_res == 0); + + DeleteCollection(collection); + DeletePartition(partition); + DeleteSegment(segment); +} + + +TEST(CApiTest, SearchTest) { + auto collection_name = "collection0"; + auto schema_tmp_conf = "null_schema"; + auto collection = NewCollection(collection_name, schema_tmp_conf); + auto partition_name = "partition0"; + auto partition = NewPartition(collection, partition_name); + auto segment = NewSegment(partition, 0); + + std::vector raw_data; + std::vector timestamps; + std::vector uids; + int N = 10000; + std::default_random_engine e(67); + for(int i = 0; i < N; ++i) { + uids.push_back(100000 + i); + timestamps.push_back(0); + // append vec + float vec[16]; + for(auto &x: vec) { + x = e() % 2000 * 0.001 - 1.0; + } + raw_data.insert(raw_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec)); + int age = e() % 100; + raw_data.insert(raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age)); + } + + auto line_sizeof = (sizeof(int) + sizeof(float) * 16); + + auto ins_res = Insert(segment, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); + assert(ins_res == 0); + + long result_ids; + float result_distances; + auto sea_res = Search(segment, nullptr, 0, &result_ids, &result_distances); + assert(sea_res == 0); + assert(result_ids == 104490); + + DeleteCollection(collection); + DeletePartition(partition); + DeleteSegment(segment); +} diff --git a/reader/index.go b/reader/index.go index 879991ac5a..9bc078bb7b 100644 --- a/reader/index.go +++ b/reader/index.go @@ -1,7 +1,7 @@ package reader import ( - schema2 "suvlim/pulsar/client-go/schema" + schema2 "github.com/czs007/suvlim/pulsar/client-go/schema" ) type IndexConfig struct {} diff --git a/reader/query_node.go b/reader/query_node.go index 04ee31dddc..0d333253c0 100644 --- a/reader/query_node.go +++ b/reader/query_node.go @@ -16,8 +16,8 @@ import "C" import ( "errors" "fmt" - "github.com/czs007/suvlim/pulsar" - "github.com/czs007/suvlim/pulsar/schema" + "github.com/czs007/suvlim/pulsar/client-go" + "github.com/czs007/suvlim/pulsar/client-go/schema" "sync" "time" ) @@ -40,13 +40,13 @@ type QueryNodeTimeSync struct { type QueryNode struct { QueryNodeId uint64 Collections []*Collection - messageClient pulsar.MessageClient + messageClient client_go.MessageClient queryNodeTimeSync *QueryNodeTimeSync buffer QueryNodeDataBuffer } func NewQueryNode(queryNodeId uint64, timeSync uint64) *QueryNode { - mc := pulsar.MessageClient{} + mc := client_go.MessageClient{} queryNodeTimeSync := &QueryNodeTimeSync { deleteTimeSync: timeSync, @@ -91,12 +91,13 @@ func (node *QueryNode) doQueryNode (wg *sync.WaitGroup) { } func (node *QueryNode) PrepareBatchMsg() { - node.messageClient.PrepareBatchMsg(pulsar.JobType(0)) + node.messageClient.PrepareBatchMsg(client_go.JobType(0)) } func (node *QueryNode) StartMessageClient() { topics := []string{"insert", "delete"} - node.messageClient.InitClient("pulsar://localhost:6650", topics) + // TODO: add consumerMsgSchema + node.messageClient.InitClient("pulsar://localhost:6650", topics, "") go node.messageClient.ReceiveMessage() } @@ -272,7 +273,7 @@ func (node *QueryNode) Delete(deleteMessages []*schema.DeleteMsg, wg *sync.WaitG // TODO: does all entities from a common batch are in the same segment? var targetSegment = node.GetSegmentByEntityId(entityIds[0]) - var result = SegmentDelete(targetSegment, collectionName, &entityIds, ×tamps) + var result = SegmentDelete(targetSegment, &entityIds, ×tamps) wg.Done() return publishResult(&result, clientId) @@ -322,8 +323,8 @@ func (node *QueryNode) Search(searchMessages []*schema.SearchMsg, wg *sync.WaitG return schema.Status{} } - var result = SegmentSearch(targetSegment, collectionName, queryString, ×tamps, &records) + var result = SegmentSearch(targetSegment, queryString, ×tamps, &records) wg.Done() - return publishResult(&result, clientId) + return publishSearchResult(result, clientId) } diff --git a/reader/result.go b/reader/result.go index 405ffef01b..de1fd24863 100644 --- a/reader/result.go +++ b/reader/result.go @@ -2,11 +2,16 @@ package reader import ( "fmt" - schema2 "suvlim/pulsar/client-go/schema" + schema2 "github.com/czs007/suvlim/pulsar/client-go/schema" ) type ResultEntityIds []int64 +type SearchResult struct { + ResultIds []int64 + ResultDistances []float32 +} + func getResultTopicByClientId(clientId int64) string { // TODO: Result topic? return "result-topic/partition-" + string(clientId) @@ -19,6 +24,13 @@ func publishResult(ids *ResultEntityIds, clientId int64) schema2.Status { return schema2.Status{Error_code: schema2.ErrorCode_SUCCESS} } +func publishSearchResult(searchResults *[]SearchResult, clientId int64) schema2.Status { + // TODO: Pulsar publish + var resultTopic = getResultTopicByClientId(clientId) + fmt.Println(resultTopic) + return schema2.Status{Error_code: schema2.ErrorCode_SUCCESS} +} + func publicStatistic(statisticTopic string) schema2.Status { // TODO: get statistic info // getStatisticInfo() diff --git a/reader/segment.go b/reader/segment.go index c57021b8d0..234256bdbc 100644 --- a/reader/segment.go +++ b/reader/segment.go @@ -13,7 +13,8 @@ package reader */ import "C" import ( - "github.com/czs007/suvlim/pulsar/schema" + "github.com/czs007/suvlim/pulsar/client-go/schema" + "unsafe" ) const SegmentLifetime = 20000 @@ -61,19 +62,75 @@ func (s *Segment) Close() { //////////////////////////////////////////////////////////////////////////// func SegmentInsert(segment *Segment, entityIds *[]int64, timestamps *[]uint64, dataChunk [][]*schema.FieldValue) ResultEntityIds { - // void* raw_data, - // int sizeof_per_row, - // signed long int count + // TODO: remove hard code schema + // auto schema_tmp = std::make_shared(); + // schema_tmp->AddField("fakeVec", DataType::VECTOR_FLOAT, 16); + // schema_tmp->AddField("age", DataType::INT32); + + /*C.Insert + int + Insert(CSegmentBase c_segment, + signed long int size, + const unsigned long* primary_keys, + const unsigned long* timestamps, + void* raw_data, + int sizeof_per_row, + signed long int count); + */ + + //msgCount := len(dataChunk) + //cEntityIds := (*C.ulong)(entityIds) + // + //// dataChunk to raw data + //var rawData []byte + //var i int + //for i = 0; i < msgCount; i++ { + // rawVector := dataChunk[i][0].VectorRecord.Records + // rawData = append(rawData, rawVector...) + //} return ResultEntityIds{} } -func SegmentDelete(segment *Segment, collectionName string, entityIds *[]int64, timestamps *[]uint64) ResultEntityIds { - // TODO: wrap cgo +func SegmentDelete(segment *Segment, entityIds *[]int64, timestamps *[]uint64) ResultEntityIds { + /*C.Delete + int + Delete(CSegmentBase c_segment, + long size, + const unsigned long* primary_keys, + const unsigned long* timestamps); + */ + size := len(*entityIds) + + // TODO: add query result status check + var _ = C.Delete(segment.SegmentPtr, C.long(size), (*C.ulong)(entityIds), (*C.ulong)(timestamps)) + return ResultEntityIds{} } -func SegmentSearch(segment *Segment, collectionName string, queryString string, timestamps *[]uint64, vectorRecord *[]schema.VectorRecord) ResultEntityIds { - // TODO: wrap cgo - return ResultEntityIds{} +func SegmentSearch(segment *Segment, queryString string, timestamps *[]uint64, vectorRecord *[]schema.VectorRecord) *[]SearchResult { + /*C.Search + int + Search(CSegmentBase c_segment, + void* fake_query, + unsigned long timestamp, + long int* result_ids, + float* result_distances); + */ + var results []SearchResult + + // TODO: get top-k's k from queryString + const TopK = 1 + + for timestamp := range *timestamps { + resultIds := make([]int64, TopK) + resultDistances := make([]float32, TopK) + + // TODO: add query result status check + var _ = C.Search(segment.SegmentPtr, unsafe.Pointer(nil), C.ulong(timestamp), (*C.long)(&resultIds[0]), (*C.float)(&resultDistances[0])) + + results = append(results, SearchResult{ResultIds: resultIds, ResultDistances: resultDistances}) + } + + return &results } diff --git a/reader/segment_test.go b/reader/segment_test.go index 475fc0787f..8465364c9a 100644 --- a/reader/segment_test.go +++ b/reader/segment_test.go @@ -15,15 +15,79 @@ func TestConstructorAndDestructor(t *testing.T) { node.DeleteCollection(collection) } -func TestSegmentInsert(t *testing.T) { +//func TestSegmentInsert(t *testing.T) { +// node := NewQueryNode(0, 0) +// var collection = node.NewCollection("collection0", "fake schema") +// var partition = collection.NewPartition("partition0") +// var segment = partition.NewSegment(0) +// +// const DIM = 4 +// const N = 3 +// +// var ids = [N]uint64{1, 2, 3} +// var timestamps = [N]uint64{0, 0, 0} +// +// var vec = [DIM]float32{1.1, 2.2, 3.3, 4.4} +// var rawData []int8 +// +// for i := 0; i <= N; i++ { +// for _, ele := range vec { +// rawData=append(rawData, int8(ele)) +// } +// rawData=append(rawData, int8(i)) +// } +// +// const sizeofPerRow = 4 + DIM * 4 +// var res = Insert(segment, N, (*C.ulong)(&ids[0]), (*C.ulong)(×tamps[0]), unsafe.Pointer(&rawData[0]), C.int(sizeofPerRow), C.long(N)) +// assert.Equal() +// +// partition.DeleteSegment(segment) +// collection.DeletePartition(partition) +// node.DeleteCollection(collection) +//} + +func TestSegmentDelete(t *testing.T) { node := NewQueryNode(0, 0) var collection = node.NewCollection("collection0", "fake schema") var partition = collection.NewPartition("partition0") var segment = partition.NewSegment(0) + ids :=[] int64{1, 2, 3} + timestamps :=[] uint64 {0, 0, 0} + SegmentDelete(segment, &ids, ×tamps) partition.DeleteSegment(segment) collection.DeletePartition(partition) node.DeleteCollection(collection) } + +//func TestSegmentSearch(t *testing.T) { +// node := NewQueryNode(0, 0) +// var collection = node.NewCollection("collection0", "fake schema") +// var partition = collection.NewPartition("partition0") +// var segment = partition.NewSegment(0) +// +// const DIM = 4 +// const N = 3 +// +// var ids = [N]uint64{1, 2, 3} +// var timestamps = [N]uint64{0, 0, 0} +// +// var vec = [DIM]float32{1.1, 2.2, 3.3, 4.4} +// var rawData []int8 +// +// for i := 0; i <= N; i++ { +// for _, ele := range vec { +// rawData=append(rawData, int8(ele)) +// } +// rawData=append(rawData, int8(i)) +// } +// +// const sizeofPerRow = 4 + DIM * 4 +// SegmentSearch(segment, "fake query string", ×tamps, nil) +// +// partition.DeleteSegment(segment) +// collection.DeletePartition(partition) +// node.DeleteCollection(collection) +//} diff --git a/storage/internal/minio/codec/codec.go b/storage/internal/minio/codec/codec.go index 60f777e365..284b91d009 100644 --- a/storage/internal/minio/codec/codec.go +++ b/storage/internal/minio/codec/codec.go @@ -9,7 +9,7 @@ func MvccEncode(key []byte, ts uint64, suffix string) ([]byte, error) { return []byte(string(key) + "_" + fmt.Sprintf("%016x", ^ts) + "_" + suffix), nil } -func MvccDecode(key string) (string, uint64, string, error) { +func MvccDecode(key []byte) (string, uint64, string, error) { if len(key) < 16 { return "", 0, "", errors.New("insufficient bytes to decode value") } @@ -18,7 +18,7 @@ func MvccDecode(key string) (string, uint64, string, error) { TSIndex := 0 undersCount := 0 for i := len(key) - 1; i > 0; i-- { - if key[i] == '_' { + if key[i] == byte('_') { undersCount++ if undersCount == 1 { suffixIndex = i + 1 @@ -34,13 +34,13 @@ func MvccDecode(key string) (string, uint64, string, error) { } var TS uint64 - _, err := fmt.Sscanf(key[TSIndex:suffixIndex-1], "%x", &TS) + _, err := fmt.Sscanf(string(key[TSIndex:suffixIndex-1]), "%x", &TS) TS = ^TS if err != nil { return "", 0, "", err } - return key[0 : TSIndex-1], TS, key[suffixIndex:], nil + return string(key[0 : TSIndex-1]), TS, string(key[suffixIndex:]), nil } func LogEncode(key []byte, ts uint64, channel int) []byte { diff --git a/storage/internal/minio/minio_store.go b/storage/internal/minio/minio_store.go index 6b6565a797..46f27fadea 100644 --- a/storage/internal/minio/minio_store.go +++ b/storage/internal/minio/minio_store.go @@ -53,7 +53,7 @@ func (s *minioDriver) put(ctx context.Context, key Key, value Value, timestamp T return err } - err = s.driver.Put(ctx, []byte(minioKey), value) + err = s.driver.Put(ctx, minioKey, value) return err } @@ -70,7 +70,7 @@ func (s *minioDriver) scanLE(ctx context.Context, key Key, timestamp Timestamp, var timestamps []Timestamp for _, key := range keys { - _, timestamp, _, _ := codec.MvccDecode(string(key)) + _, timestamp, _, _ := codec.MvccDecode(key) timestamps = append(timestamps, timestamp) } @@ -90,7 +90,7 @@ func (s *minioDriver) scanGE(ctx context.Context, key Key, timestamp Timestamp, var timestamps []Timestamp for _, key := range keys { - _, timestamp, _, _ := codec.MvccDecode(string(key)) + _, timestamp, _, _ := codec.MvccDecode(key) timestamps = append(timestamps, timestamp) } @@ -139,7 +139,7 @@ func (s *minioDriver) GetRow(ctx context.Context, key Key, timestamp Timestamp) return nil, err } - _, _, suffix, err := MvccDecode(string(keys[0])) + _, _, suffix, err := MvccDecode(keys[0]) if err != nil{ return nil, err } @@ -149,10 +149,10 @@ func (s *minioDriver) GetRow(ctx context.Context, key Key, timestamp Timestamp) return values[0], err } -func (s *minioDriver) GetRows(ctx context.Context, keys []Key, timestamp Timestamp) ([]Value, error){ +func (s *minioDriver) GetRows(ctx context.Context, keys []Key, timestamps []Timestamp) ([]Value, error){ var values []Value - for _, key := range keys{ - value, err := s.GetRow(ctx, key, timestamp) + for i, key := range keys{ + value, err := s.GetRow(ctx, key, timestamps[i]) if err!= nil{ return nil, err } @@ -169,7 +169,7 @@ func (s *minioDriver) PutRow(ctx context.Context, key Key, value Value, segment err = s.driver.Put(ctx, minioKey, value) return err } -func (s *minioDriver) PutRows(ctx context.Context, keys []Key, values []Value, segment string, timestamp Timestamp) error{ +func (s *minioDriver) PutRows(ctx context.Context, keys []Key, values []Value, segments []string, timestamps []Timestamp) error{ maxThread := 100 batchSize := 1 keysLength := len(keys) @@ -185,9 +185,9 @@ func (s *minioDriver) PutRows(ctx context.Context, keys []Key, values []Value, s } errCh := make(chan error) - f := func(ctx2 context.Context, keys2 []Key, values2 []Value, segments2 string, timestamp2 Timestamp) { + f := func(ctx2 context.Context, keys2 []Key, values2 []Value, segments2 []string, timestamps2 []Timestamp) { for i := 0; i < len(keys2); i++{ - err := s.PutRow(ctx2, keys2[i], values2[i], segments2, timestamp2) + err := s.PutRow(ctx2, keys2[i], values2[i], segments2[i], timestamps2[i]) errCh <- err } } @@ -198,7 +198,7 @@ func (s *minioDriver) PutRows(ctx context.Context, keys []Key, values []Value, s if len(keys) < end { end = len(keys) } - f(ctx, keys[start:end], values[start:end], segment, timestamp) + f(ctx, keys[start:end], values[start:end], segments[start:end], timestamps[start:end]) }() } @@ -210,6 +210,33 @@ func (s *minioDriver) PutRows(ctx context.Context, keys []Key, values []Value, s return nil } +func (s *minioDriver) GetSegments(ctx context.Context, key Key, timestamp Timestamp) ([]string, error){ + keyEnd, err := MvccEncode(key, timestamp, "") + if err != nil{ + return nil, err + } + keys, _, err := s.driver.Scan(ctx, append(key, byte('_')), keyEnd, -1,true) + if err != nil { + return nil, err + } + segmentsSet := map[string]bool{} + for _, key := range keys { + _, _, segment, err := MvccDecode(key) + if err != nil { + panic("must no error") + } + segmentsSet[segment] = true + } + + var segments []string + for k, v := range segmentsSet { + if v == true { + segments = append(segments, k) + } + } + return segments, err +} + func (s *minioDriver) DeleteRow(ctx context.Context, key Key, timestamp Timestamp) error{ minioKey, err := MvccEncode(key, timestamp, "delete") if err != nil{ @@ -220,7 +247,7 @@ func (s *minioDriver) DeleteRow(ctx context.Context, key Key, timestamp Timestam return err } -func (s *minioDriver) DeleteRows(ctx context.Context, keys []Key, timestamp Timestamp) error{ +func (s *minioDriver) DeleteRows(ctx context.Context, keys []Key, timestamps []Timestamp) error{ maxThread := 100 batchSize := 1 keysLength := len(keys) @@ -236,9 +263,9 @@ func (s *minioDriver) DeleteRows(ctx context.Context, keys []Key, timestamp Time } errCh := make(chan error) - f := func(ctx2 context.Context, keys2 []Key, timestamp2 Timestamp) { + f := func(ctx2 context.Context, keys2 []Key, timestamps2 []Timestamp) { for i := 0; i < len(keys2); i++{ - err := s.DeleteRow(ctx2, keys2[i], timestamp2) + err := s.DeleteRow(ctx2, keys2[i], timestamps2[i]) errCh <- err } } @@ -249,7 +276,7 @@ func (s *minioDriver) DeleteRows(ctx context.Context, keys []Key, timestamp Time if len(keys) < end { end = len(keys) } - f(ctx, keys[start:end], timestamp) + f(ctx, keys[start:end], timestamps[start:end]) }() } diff --git a/storage/internal/minio/minio_storeEngine.go b/storage/internal/minio/minio_storeEngine.go index 3bce5cbf39..620e72b66b 100644 --- a/storage/internal/minio/minio_storeEngine.go +++ b/storage/internal/minio/minio_storeEngine.go @@ -75,7 +75,7 @@ func (s *minioStore) Scan(ctx context.Context, keyStart Key, keyEnd Key, limit i } values = append(values, value) } - limitCount--; + limitCount-- if limitCount <= 0{ break } diff --git a/storage/internal/minio/minio_test.go b/storage/internal/minio/minio_test.go index 98507a40ae..6baa55436c 100644 --- a/storage/internal/minio/minio_test.go +++ b/storage/internal/minio/minio_test.go @@ -41,6 +41,28 @@ func TestMinioDriver_DeleteRow(t *testing.T){ assert.Nil(t, object2) } +func TestMinioDriver_GetSegments(t *testing.T) { + err = client.PutRow(ctx, []byte("seg"), []byte("abcdefghijklmnoopqrstuvwxyz"), "SegmentA", 1234567) + assert.Nil(t, err) + err = client.PutRow(ctx, []byte("seg"), []byte("djhfkjsbdfbsdughorsgsdjhgoisdgh"), "SegmentA", 1235567) + assert.Nil(t, err) + err = client.PutRow(ctx, []byte("seg"), []byte("123854676ershdgfsgdfk,sdhfg;sdi8"), "SegmentB", 1236567) + assert.Nil(t, err) + err = client.PutRow(ctx, []byte("seg2"), []byte("testkeybarorbar_1"), "SegmentC", 1236567) + assert.Nil(t, err) + + segements, err := client.GetSegments(ctx, []byte("seg"), 1237777) + assert.Nil(t, err) + assert.Equal(t, 2, len(segements)) + if segements[0] == "SegmentA" { + assert.Equal(t, "SegmentA", segements[0]) + assert.Equal(t, "SegmentB", segements[1]) + } else { + assert.Equal(t, "SegmentB", segements[0]) + assert.Equal(t, "SegmentA", segements[1]) + } +} + func TestMinioDriver_PutRowsAndGetRows(t *testing.T){ keys := [][]byte{[]byte("foo"), []byte("bar")} values := [][]byte{[]byte("The key is foo!"), []byte("The key is bar!")} diff --git a/storage/internal/tikv/tikv_store.go b/storage/internal/tikv/tikv_store.go index 5924418d05..63d65679bd 100644 --- a/storage/internal/tikv/tikv_store.go +++ b/storage/internal/tikv/tikv_store.go @@ -230,7 +230,7 @@ func (s *TikvStore) PutRow(ctx context.Context, key Key, value Value, segment st return s.put(ctx, key, value, timestamp, segment) } -func (s *TikvStore) PutRows(ctx context.Context, keys []Key, values []Value, segment string, timestamps []Timestamp) error { +func (s *TikvStore) PutRows(ctx context.Context, keys []Key, values []Value, segments []string, timestamps []Timestamp) error { if len(keys) != len(values) { return errors.New("the len of keys is not equal to the len of values") } @@ -240,7 +240,7 @@ func (s *TikvStore) PutRows(ctx context.Context, keys []Key, values []Value, seg encodedKeys := make([]Key, len(keys)) for i, key := range keys { - encodedKeys[i] = EncodeKey(key, timestamps[i], segment) + encodedKeys[i] = EncodeKey(key, timestamps[i], segments[i]) } return s.engine.BatchPut(ctx, encodedKeys, values) } @@ -249,11 +249,11 @@ func (s *TikvStore) DeleteRow(ctx context.Context, key Key, timestamp Timestamp) return s.put(ctx, key, Value{0x00}, timestamp, string(DeleteMark)) } -func (s *TikvStore) DeleteRows(ctx context.Context, keys []Key, timestamp Timestamp) error { +func (s *TikvStore) DeleteRows(ctx context.Context, keys []Key, timestamps []Timestamp) error { encodeKeys := make([]Key, len(keys)) values := make([]Value, len(keys)) for i, key := range keys { - encodeKeys[i] = EncodeKey(key, timestamp, string(DeleteMark)) + encodeKeys[i] = EncodeKey(key, timestamps[i], string(DeleteMark)) values[i] = Value{0x00} } return s.engine.BatchPut(ctx, encodeKeys, values) diff --git a/storage/pkg/types/types.go b/storage/pkg/types/types.go index 4ad33754a2..991a5e2014 100644 --- a/storage/pkg/types/types.go +++ b/storage/pkg/types/types.go @@ -47,15 +47,15 @@ type Store interface { //deleteRange(ctx context.Context, key Key, start Timestamp, end Timestamp) error GetRow(ctx context.Context, key Key, timestamp Timestamp) (Value, error) - GetRows(ctx context.Context, keys []Key, timestamp Timestamp) ([]Value, error) + GetRows(ctx context.Context, keys []Key, timestamps []Timestamp) ([]Value, error) PutRow(ctx context.Context, key Key, value Value, segment string, timestamp Timestamp) error - PutRows(ctx context.Context, keys []Key, values []Value, segment string, timestamp []Timestamp) error + PutRows(ctx context.Context, keys []Key, values []Value, segments []string, timestamps []Timestamp) error GetSegments(ctx context.Context, key Key, timestamp Timestamp) ([]string, error) DeleteRow(ctx context.Context, key Key, timestamp Timestamp) error - DeleteRows(ctx context.Context, keys []Key, timestamp []Timestamp) error + DeleteRows(ctx context.Context, keys []Key, timestamps []Timestamp) error PutLog(ctx context.Context, key Key, value Value, timestamp Timestamp, channel int) error GetLog(ctx context.Context, start Timestamp, end Timestamp, channels []int) ([]Value, error)