milvus/cpp/src/wrapper/VecImpl.h
wxyu 0712798938 MS-631 IVFSQ8H Index support
Former-commit-id: 21e17a20794e4fde31e79c4bbd4e26d46c79d886
2019-10-10 20:57:14 +08:00

149 lines
3.7 KiB
C++

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include "VecIndex.h"
#include "knowhere/index/vector_index/VectorIndex.h"
#include <memory>
#include <utility>
namespace milvus {
namespace engine {
/**
 * @brief Adapter exposing a knowhere::VectorIndex through the engine-level VecIndex interface.
 *
 * The wrapper holds a shared handle to the underlying knowhere index plus the IndexType tag it was
 * constructed with. All overridden operations are defined out of line (not visible in this header).
 */
class VecIndexImpl : public VecIndex {
 public:
    /**
     * @param index      knowhere index to wrap; the shared_ptr is moved into the wrapper.
     * @param index_type type tag stored in the `type` member.
     */
    explicit VecIndexImpl(std::shared_ptr<knowhere::VectorIndex> index, const IndexType& index_type)
        // Initializers are written in member declaration order (type, then index_): members are always
        // initialized in declaration order, so this keeps the list honest and silences -Wreorder.
        // The parameter is named index_type so it no longer shadows the `type` member.
        : type(index_type), index_(std::move(index)) {
    }

    // Builds the index from nb vectors at xb (with ids) under cfg; nt/xt presumably describe a
    // separate training set — NOTE(review): confirm against the out-of-line definition.
    Status
    BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
             const float* xt) override;

    // Produces a copy of this index for GPU device_id.
    VecIndexPtr
    CopyToGpu(const int64_t& device_id, const Config& cfg) override;

    // Produces a copy of this index in host memory.
    VecIndexPtr
    CopyToCpu(const Config& cfg) override;

    // Presumably reports the `type` tag given at construction — confirm in the .cpp.
    IndexType
    GetType() override;

    int64_t
    Dimension() override;

    int64_t
    Count() override;

    // Adds nb vectors at xb (with ids) to an already-built index.
    Status
    Add(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg) override;

    knowhere::BinarySet
    Serialize() override;

    // Restores index state from a knowhere::BinarySet (counterpart of Serialize()).
    Status
    Load(const knowhere::BinarySet& index_binary) override;

    VecIndexPtr
    Clone() override;

    int64_t
    GetDeviceId() override;

    // Queries nq vectors at xq; results are written to the caller-provided dist and ids buffers.
    Status
    Search(const int64_t& nq, const float* xq, float* dist, int64_t* ids, const Config& cfg) override;

 protected:
    int64_t dim = 0;                                       // not set by the constructor; defaults to 0
    IndexType type = IndexType::INVALID;                   // tag supplied at construction
    std::shared_ptr<knowhere::VectorIndex> index_ = nullptr;  // wrapped knowhere index (shared ownership)
};
/**
 * @brief VecIndexImpl variant that customizes index building and loading.
 *
 * Only BuildAll() and Load() are overridden here. Every other operation is inherited unchanged
 * from VecIndexImpl.
 */
class IVFMixIndex : public VecIndexImpl {
 public:
    /// Forwards the wrapped knowhere index and its type tag straight to VecIndexImpl.
    explicit IVFMixIndex(std::shared_ptr<knowhere::VectorIndex> vec_index, const IndexType& index_type)
        : VecIndexImpl(std::move(vec_index), index_type) {
    }

    /// Deserialization override (defined out of line).
    Status
    Load(const knowhere::BinarySet& index_binary) override;

    /// Build override (defined out of line); nt/xt presumably a training set — confirm in the .cpp.
    Status
    BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
             const float* xt) override;
};
/**
 * @brief IVFMixIndex variant that adds quantizer-related operations.
 *
 * Overrides the quantizer hooks declared on VecIndex. Their definitions live out of line.
 */
class IVFHybridIndex : public IVFMixIndex {
 public:
    /// Hands the wrapped knowhere index and its type tag to IVFMixIndex.
    explicit IVFHybridIndex(std::shared_ptr<knowhere::VectorIndex> vec_index, const IndexType& index_type)
        : IVFMixIndex(std::move(vec_index), index_type) {
    }

    // --- quantizer hooks ---

    /// Obtains a quantizer configured by conf.
    knowhere::QuantizerPtr
    LoadQuantizer(const Config& conf) override;

    /// Installs quantizer q on this index.
    Status
    SetQuantizer(const knowhere::QuantizerPtr& q) override;

    /// Removes a previously installed quantizer.
    Status
    UnsetQuantizer() override;

    /// Loads index data using quantizer q and configuration conf.
    Status
    LoadData(const knowhere::QuantizerPtr& q, const Config& conf) override;
};
/**
 * @brief Brute-force wrapper whose type tag is always IndexType::FAISS_IDMAP.
 *
 * Adds non-virtual helpers for building from a Config alone and for reading back the stored
 * vectors and ids.
 */
class BFIndex : public VecIndexImpl {
 public:
    /// Wraps vec_index and fixes the type tag to FAISS_IDMAP.
    explicit BFIndex(std::shared_ptr<knowhere::VectorIndex> vec_index)
        : VecIndexImpl(std::move(vec_index), IndexType::FAISS_IDMAP) {
    }

    /// Initializes the index from cfg. NOTE(review): returns ErrorCode, unlike the Status-returning
    /// members, so callers must handle both conventions.
    ErrorCode
    Build(const Config& cfg);

    /// Build override (defined out of line).
    Status
    BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt,
             const float* xt) override;

    // --- raw data accessors (pointer ownership/lifetime not visible here — confirm in the .cpp) ---

    float*
    GetRawVectors();

    int64_t*
    GetRawIds();
};
class ToIndexData : public cache::DataObj {
public:
explicit ToIndexData(int64_t size) : size_(size) {
}
int64_t
Size() override {
return size_;
}
private:
int64_t size_;
};
} // namespace engine
} // namespace milvus