feat:add new gpu index:GPU_BRUTE_FORCE and limit gpu index metric type (#29590)

issue: https://github.com/milvus-io/milvus/issues/29230
this pr do these things:
1. add gpu brute force;
2. limit gpu index only support l2 / ip;

Signed-off-by: cqy123456 <qianya.cheng@zilliz.com>
This commit is contained in:
cqy123456 2024-01-05 15:24:48 +08:00 committed by GitHub
parent c8db36a63a
commit 22bb84fa9d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 281 additions and 8 deletions

View File

@ -93,9 +93,16 @@ KnowhereInitGPUMemoryPool(const uint32_t init_size, const uint32_t max_size) {
if (init_size == 0 && max_size == 0) {
knowhere::KnowhereConfig::SetRaftMemPool();
return;
} else if (init_size > max_size) {
PanicInfo(ConfigInvalid,
"Error Gpu memory pool params: init_size {} can't not large "
"than max_size {}.",
init_size,
max_size);
} else {
knowhere::KnowhereConfig::SetRaftMemPool(size_t{init_size},
size_t{max_size});
}
knowhere::KnowhereConfig::SetRaftMemPool(size_t{init_size},
size_t{max_size});
}
int32_t

View File

@ -43,9 +43,10 @@ func (mgr *indexCheckerMgrImpl) GetChecker(indexType string) (IndexChecker, erro
}
func (mgr *indexCheckerMgrImpl) registerIndexChecker() {
mgr.checkers[IndexRaftIvfFlat] = newIVFBaseChecker()
mgr.checkers[IndexRaftIvfFlat] = newRaftIVFFlatChecker()
mgr.checkers[IndexRaftIvfPQ] = newRaftIVFPQChecker()
mgr.checkers[IndexRaftCagra] = newCagraChecker()
mgr.checkers[IndexRaftBruteForce] = newRaftBruteForceChecker()
mgr.checkers[IndexFaissIDMap] = newFlatChecker()
mgr.checkers[IndexFaissIvfFlat] = newIVFBaseChecker()
mgr.checkers[IndexFaissIvfPQ] = newIVFPQChecker()

View File

@ -50,9 +50,10 @@ var (
BinIDMapMetrics = []string{metric.HAMMING, metric.JACCARD, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE} // const
BinIvfMetrics = []string{metric.HAMMING, metric.JACCARD} // const
HnswMetrics = []string{metric.L2, metric.IP, metric.COSINE, metric.HAMMING, metric.JACCARD} // const
CagraMetrics = []string{metric.L2} // const
supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const
supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const
RaftMetrics = []string{metric.L2, metric.IP}
CagraMetrics = []string{metric.L2} // const
supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const
supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const
)
const (

View File

@ -20,6 +20,7 @@ const (
IndexRaftIvfFlat IndexType = "GPU_IVF_FLAT"
IndexRaftIvfPQ IndexType = "GPU_IVF_PQ"
IndexRaftCagra IndexType = "GPU_CAGRA"
IndexRaftBruteForce IndexType = "GPU_BRUTE_FORCE"
IndexFaissIDMap IndexType = "FLAT" // no index is built.
IndexFaissIvfFlat IndexType = "IVF_FLAT"
IndexFaissIvfPQ IndexType = "IVF_PQ"

View File

@ -0,0 +1,22 @@
package indexparamcheck
import "fmt"
type raftBruteForceChecker struct {
floatVectorBaseChecker
}
// raftBrustForceChecker checks if a Brute_Force index can be built.
func (c raftBruteForceChecker) CheckTrain(params map[string]string) error {
if err := c.floatVectorBaseChecker.CheckTrain(params); err != nil {
return err
}
if !CheckStrByValues(params, Metric, RaftMetrics) {
return fmt.Errorf("metric type not found or not supported, supported: %v", RaftMetrics)
}
return nil
}
func newRaftBruteForceChecker() IndexChecker {
return &raftBruteForceChecker{}
}

View File

@ -0,0 +1,64 @@
package indexparamcheck
import (
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/pkg/util/metric"
)
func Test_raftbfChecker_CheckTrain(t *testing.T) {
p1 := map[string]string{
DIM: strconv.Itoa(128),
Metric: metric.L2,
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
Metric: metric.IP,
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
Metric: metric.COSINE,
}
p4 := map[string]string{
DIM: strconv.Itoa(128),
Metric: metric.HAMMING,
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
Metric: metric.JACCARD,
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
Metric: metric.SUBSTRUCTURE,
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
Metric: metric.SUPERSTRUCTURE,
}
cases := []struct {
params map[string]string
errIsNil bool
}{
{p1, true},
{p2, true},
{p3, false},
{p4, false},
{p5, false},
{p6, false},
{p7, false},
}
c := newRaftBruteForceChecker()
for _, test := range cases {
err := c.CheckTrain(test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}

View File

@ -0,0 +1,23 @@
package indexparamcheck
import "fmt"
// raftIVFChecker checks if a RAFT_IVF_Flat index can be built.
type raftIVFFlatChecker struct {
ivfBaseChecker
}
// CheckTrain checks if ivf-flat index can be built with the specific index parameters.
func (c *raftIVFFlatChecker) CheckTrain(params map[string]string) error {
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
return err
}
if !CheckStrByValues(params, Metric, RaftMetrics) {
return fmt.Errorf("metric type not found or not supported, supported: %v", RaftMetrics)
}
return nil
}
func newRaftIVFFlatChecker() IndexChecker {
return &raftIVFFlatChecker{}
}

View File

@ -0,0 +1,152 @@
package indexparamcheck
import (
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/metric"
)
func Test_raftIvfFlatChecker_CheckTrain(t *testing.T) {
validParams := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: metric.L2,
}
p1 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: metric.L2,
}
p2 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: metric.IP,
}
p3 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: metric.COSINE,
}
p4 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: metric.HAMMING,
}
p5 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: metric.JACCARD,
}
p6 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: metric.SUBSTRUCTURE,
}
p7 := map[string]string{
DIM: strconv.Itoa(128),
NLIST: strconv.Itoa(1024),
Metric: metric.SUPERSTRUCTURE,
}
cases := []struct {
params map[string]string
errIsNil bool
}{
{validParams, true},
{invalidIVFParamsMin(), false},
{invalidIVFParamsMax(), false},
{p1, true},
{p2, true},
{p3, false},
{p4, false},
{p5, false},
{p6, false},
{p7, false},
}
c := newRaftIVFFlatChecker()
for _, test := range cases {
err := c.CheckTrain(test.params)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}
func Test_raftIvfFlatChecker_CheckValidDataType(t *testing.T) {
cases := []struct {
dType schemapb.DataType
errIsNil bool
}{
{
dType: schemapb.DataType_Bool,
errIsNil: false,
},
{
dType: schemapb.DataType_Int8,
errIsNil: false,
},
{
dType: schemapb.DataType_Int16,
errIsNil: false,
},
{
dType: schemapb.DataType_Int32,
errIsNil: false,
},
{
dType: schemapb.DataType_Int64,
errIsNil: false,
},
{
dType: schemapb.DataType_Float,
errIsNil: false,
},
{
dType: schemapb.DataType_Double,
errIsNil: false,
},
{
dType: schemapb.DataType_String,
errIsNil: false,
},
{
dType: schemapb.DataType_VarChar,
errIsNil: false,
},
{
dType: schemapb.DataType_Array,
errIsNil: false,
},
{
dType: schemapb.DataType_JSON,
errIsNil: false,
},
{
dType: schemapb.DataType_FloatVector,
errIsNil: true,
},
{
dType: schemapb.DataType_BinaryVector,
errIsNil: false,
},
}
c := newRaftIVFFlatChecker()
for _, test := range cases {
err := c.CheckValidDataType(test.dType)
if test.errIsNil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
}
}

View File

@ -15,7 +15,9 @@ func (c *raftIVFPQChecker) CheckTrain(params map[string]string) error {
if err := c.ivfBaseChecker.CheckTrain(params); err != nil {
return err
}
if !CheckStrByValues(params, Metric, RaftMetrics) {
return fmt.Errorf("metric type not found or not supported, supported: %v", RaftMetrics)
}
return c.checkPQParams(params)
}

View File

@ -123,7 +123,7 @@ func Test_raftIVFPQChecker_CheckTrain(t *testing.T) {
{validParamsMzero, true},
{p1, true},
{p2, true},
{p3, true},
{p3, false},
{p4, false},
{p5, false},
{p6, false},