mirror of https://gitee.com/milvus-io/milvus.git

commit cc641236d2 (parent a6e92dc9ab)

    add TopK

Former-commit-id: 847a46e3b7ab8bead610ec9888fce5a4279a6920
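In short, this commit wires CUDA into the build, adds a type-limits helper header (Arithmetic.h), a GPU top-k selection pair adapted from faiss (a fast path for k == 1 and a BlockSelect-based path for k > 1), a host-facing wrapper, and a unit test. A minimal host-side sketch of the new public entry point, mirroring the unit test at the bottom of the diff (error handling omitted; the current implementation always returns the k smallest values in ascending order):

#include <cstdint>
#include <vector>
#include "wrapper/Topk.h"

int main() {
    std::vector<float> input(100000, 0.0f);
    input[42] = -5.0f;  // the single smallest value

    int k = 10;
    std::vector<float> out(k);
    std::vector<int64_t> idx(k);

    // Host overload: copies input to the device, selects the k minima,
    // copies distances and indices back.
    zilliz::vecwise::engine::TopK(input.data(), (int) input.size(), k,
                                  out.data(), idx.data());
    // out[0] == -5.0f, idx[0] == 42
    return 0;
}

The individual changes follow.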
@@ -6,7 +6,13 @@
 cmake_minimum_required(VERSION 3.12)
 
-project(vecwise_engine)
+project(vecwise_engine LANGUAGES CUDA CXX)
+
+find_package(CUDA)
+set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler -fPIC -std=c++11 -D_FORCE_INLINES -arch sm_60 --expt-extended-lambda")
+set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O0 -g")
+message("CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}")
+message("CUDA_NVCC_FLAGS=${CUDA_NVCC_FLAGS}")
 
 set(CMAKE_CXX_STANDARD 14)

@@ -27,7 +27,7 @@ find_library(cuda_library cudart cublas HINTS /usr/local/cuda/lib64)
 
 add_library(vecwise_engine STATIC ${vecwise_engine_src})
 
-add_executable(vecwise_server
+cuda_add_executable(vecwise_server
     ${config_files}
     ${server_files}
     ${utils_files}
cpp/src/wrapper/Arithmetic.h (new file, 67 lines)
@@ -0,0 +1,67 @@
/*******************************************************************************
 * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
 * Unauthorized copying of this file, via any medium is strictly prohibited.
 * Proprietary and confidential.
 ******************************************************************************/
#pragma once

#include <cstdint>
#include <cstddef>
#include <limits>


namespace zilliz {
namespace vecwise {
namespace engine {

using Bool = int8_t;
using Byte = uint8_t;
using Word = unsigned long;
using EnumType = uint64_t;

using Float32 = float;
using Float64 = double;

constexpr bool kBoolMax = std::numeric_limits<bool>::max();
constexpr bool kBoolMin = std::numeric_limits<bool>::lowest();

constexpr int8_t kInt8Max = std::numeric_limits<int8_t>::max();
constexpr int8_t kInt8Min = std::numeric_limits<int8_t>::lowest();

constexpr int16_t kInt16Max = std::numeric_limits<int16_t>::max();
constexpr int16_t kInt16Min = std::numeric_limits<int16_t>::lowest();

constexpr int32_t kInt32Max = std::numeric_limits<int32_t>::max();
constexpr int32_t kInt32Min = std::numeric_limits<int32_t>::lowest();

constexpr int64_t kInt64Max = std::numeric_limits<int64_t>::max();
constexpr int64_t kInt64Min = std::numeric_limits<int64_t>::lowest();

constexpr float kFloatMax = std::numeric_limits<float>::max();
constexpr float kFloatMin = std::numeric_limits<float>::lowest();

constexpr double kDoubleMax = std::numeric_limits<double>::max();
constexpr double kDoubleMin = std::numeric_limits<double>::lowest();

constexpr uint32_t kFloat32DecimalPrecision = std::numeric_limits<Float32>::digits10;
constexpr uint32_t kFloat64DecimalPrecision = std::numeric_limits<Float64>::digits10;


constexpr uint8_t kByteWidth = 8;
constexpr uint8_t kCharWidth = kByteWidth;
constexpr uint8_t kWordWidth = sizeof(Word) * kByteWidth;
constexpr uint8_t kEnumTypeWidth = sizeof(EnumType) * kByteWidth;

template<typename T>
inline size_t
WidthOf() { return sizeof(T) << 3; }

template<typename T>
inline size_t
WidthOf(const T &) { return sizeof(T) << 3; }


} // namespace engine
} // namespace vecwise
} // namespace zilliz
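The header is self-contained; a minimal sketch of the width helpers and limit constants in use (the expected values follow directly from the definitions above, assuming a typical LP64 target):

#include <cassert>
#include "wrapper/Arithmetic.h"

int main() {
    using namespace zilliz::vecwise::engine;
    assert(WidthOf<Float64>() == 64);   // sizeof(double) * 8
    assert(WidthOf(Byte{0}) == 8);      // overload deducing T from the argument
    assert(kInt16Max == 32767 && kInt16Min == -32768);
    return 0;
}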
cpp/src/wrapper/Topk.cu (new file, 574 lines)
@@ -0,0 +1,574 @@
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////

#include "faiss/FaissAssert.h"
#include "faiss/gpu/utils/Limits.cuh"
#include "Arithmetic.h"

namespace faiss {
namespace gpu {

constexpr bool kBoolMax = zilliz::vecwise::engine::kBoolMax;
constexpr bool kBoolMin = zilliz::vecwise::engine::kBoolMin;

template<>
struct Limits<bool> {
    static __device__ __host__
    inline bool getMin() {
        return kBoolMin;
    }
    static __device__ __host__
    inline bool getMax() {
        return kBoolMax;
    }
};

constexpr int8_t kInt8Max = zilliz::vecwise::engine::kInt8Max;
constexpr int8_t kInt8Min = zilliz::vecwise::engine::kInt8Min;

template<>
struct Limits<int8_t> {
    static __device__ __host__
    inline int8_t getMin() {
        return kInt8Min;
    }
    static __device__ __host__
    inline int8_t getMax() {
        return kInt8Max;
    }
};

constexpr int16_t kInt16Max = zilliz::vecwise::engine::kInt16Max;
constexpr int16_t kInt16Min = zilliz::vecwise::engine::kInt16Min;

template<>
struct Limits<int16_t> {
    static __device__ __host__
    inline int16_t getMin() {
        return kInt16Min;
    }
    static __device__ __host__
    inline int16_t getMax() {
        return kInt16Max;
    }
};

constexpr int64_t kInt64Max = zilliz::vecwise::engine::kInt64Max;
constexpr int64_t kInt64Min = zilliz::vecwise::engine::kInt64Min;

template<>
struct Limits<int64_t> {
    static __device__ __host__
    inline int64_t getMin() {
        return kInt64Min;
    }
    static __device__ __host__
    inline int64_t getMax() {
        return kInt64Max;
    }
};

constexpr double kDoubleMax = zilliz::vecwise::engine::kDoubleMax;
constexpr double kDoubleMin = zilliz::vecwise::engine::kDoubleMin;

template<>
struct Limits<double> {
    static __device__ __host__
    inline double getMin() {
        return kDoubleMin;
    }
    static __device__ __host__
    inline double getMax() {
        return kDoubleMax;
    }
};

} // namespace gpu
} // namespace faiss
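// Note: only the Limits<> specializations missing from this faiss version's
// Limits.cuh are supplied above; faiss itself appears to ship Limits<> for
// float, half, and int already, which is why int32_t and float need no
// treatment here. Without these, instantiating the selection kernels below
// for bool/int8_t/int16_t/int64_t/double would fail to compile.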
#include "faiss/gpu/utils/DeviceUtils.h"
#include "faiss/gpu/utils/MathOperators.cuh"
#include "faiss/gpu/utils/Pair.cuh"
#include "faiss/gpu/utils/Reductions.cuh"
#include "faiss/gpu/utils/Select.cuh"
#include "faiss/gpu/utils/Tensor.cuh"
#include "faiss/gpu/utils/StaticUtils.h"

#include "Topk.h"


namespace zilliz {
namespace vecwise {
namespace engine {
namespace gpu {

constexpr int kWarpSize = 32;

template<typename T, int Dim, bool InnerContig>
using Tensor = faiss::gpu::Tensor<T, Dim, InnerContig>;

template<typename T, typename U>
using Pair = faiss::gpu::Pair<T, U>;

// select kernel for k == 1
template<typename T, int kRowsPerBlock, int kBlockSize>
__global__ void topkSelectMin1(Tensor<T, 2, true> productDistances,
                               Tensor<T, 2, true> outDistances,
                               Tensor<int64_t, 2, true> outIndices) {
    // Each block handles kRowsPerBlock rows of the distances (results)
    Pair<T, int64_t> threadMin[kRowsPerBlock];
    __shared__ Pair<T, int64_t> blockMin[kRowsPerBlock * (kBlockSize / kWarpSize)];

    T distance[kRowsPerBlock];

#pragma unroll
    for (int i = 0; i < kRowsPerBlock; ++i) {
        threadMin[i].k = faiss::gpu::Limits<T>::getMax();
        threadMin[i].v = -1;
    }

    // blockIdx.x: which chunk of rows we are responsible for updating
    int rowStart = blockIdx.x * kRowsPerBlock;

    // FIXME: if we have exact multiples, don't need this
    bool endRow = (blockIdx.x == gridDim.x - 1);

    if (endRow) {
        if (productDistances.getSize(0) % kRowsPerBlock == 0) {
            endRow = false;
        }
    }

    if (endRow) {
        for (int row = rowStart; row < productDistances.getSize(0); ++row) {
            for (int col = threadIdx.x; col < productDistances.getSize(1);
                 col += blockDim.x) {
                distance[0] = productDistances[row][col];

                if (faiss::gpu::Math<T>::lt(distance[0], threadMin[0].k)) {
                    threadMin[0].k = distance[0];
                    threadMin[0].v = col;
                }
            }

            // Reduce within the block
            threadMin[0] =
                faiss::gpu::blockReduceAll<Pair<T, int64_t>, faiss::gpu::Min<Pair<T, int64_t> >, false, false>(
                    threadMin[0], faiss::gpu::Min<Pair<T, int64_t> >(), blockMin);

            if (threadIdx.x == 0) {
                outDistances[row][0] = threadMin[0].k;
                outIndices[row][0] = threadMin[0].v;
            }

            // so we can use the shared memory again
            __syncthreads();

            threadMin[0].k = faiss::gpu::Limits<T>::getMax();
            threadMin[0].v = -1;
        }
    } else {
        for (int col = threadIdx.x; col < productDistances.getSize(1);
             col += blockDim.x) {

#pragma unroll
            for (int row = 0; row < kRowsPerBlock; ++row) {
                distance[row] = productDistances[rowStart + row][col];
            }

#pragma unroll
            for (int row = 0; row < kRowsPerBlock; ++row) {
                if (faiss::gpu::Math<T>::lt(distance[row], threadMin[row].k)) {
                    threadMin[row].k = distance[row];
                    threadMin[row].v = col;
                }
            }
        }

        // Reduce within the block
        faiss::gpu::blockReduceAll<kRowsPerBlock, Pair<T, int64_t>, faiss::gpu::Min<Pair<T, int64_t> >, false, false>(
            threadMin, faiss::gpu::Min<Pair<T, int64_t> >(), blockMin);

        if (threadIdx.x == 0) {
#pragma unroll
            for (int row = 0; row < kRowsPerBlock; ++row) {
                outDistances[rowStart + row][0] = threadMin[row].k;
                outIndices[rowStart + row][0] = threadMin[row].v;
            }
        }
    }
}

// select kernel for k > 1 (adapted from faiss's L2 select; there is no
// ||c||^2 re-use here since this is plain selection, not L2 distance)
template<typename T, int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void topkSelectMinK(Tensor<T, 2, true> productDistances,
                               Tensor<T, 2, true> outDistances,
                               Tensor<int64_t, 2, true> outIndices,
                               int k, T initK) {
    // Each block handles a single row of the distances (results)
    constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

    __shared__ T smemK[kNumWarps * NumWarpQ];
    __shared__ int64_t smemV[kNumWarps * NumWarpQ];

    faiss::gpu::BlockSelect<T, int64_t, false, faiss::gpu::Comparator<T>,
                            NumWarpQ, NumThreadQ, ThreadsPerBlock>
        heap(initK, -1, smemK, smemV, k);

    int row = blockIdx.x;

    // Whole warps must participate in the selection
    int limit = faiss::gpu::utils::roundDown(productDistances.getSize(1), kWarpSize);
    int i = threadIdx.x;

    for (; i < limit; i += blockDim.x) {
        T v = productDistances[row][i];
        heap.add(v, i);
    }

    if (i < productDistances.getSize(1)) {
        T v = productDistances[row][i];
        heap.addThreadQ(v, i);
    }

    heap.reduce();
    for (int i = threadIdx.x; i < k; i += blockDim.x) {
        outDistances[row][i] = smemK[i];
        outIndices[row][i] = smemV[i];
    }
}

// FIXME: no TVec specialization
template<typename T>
void runTopKSelectMin(Tensor<T, 2, true> &productDistances,
                      Tensor<T, 2, true> &outDistances,
                      Tensor<int64_t, 2, true> &outIndices,
                      int k,
                      cudaStream_t stream) {
    FAISS_ASSERT(productDistances.getSize(0) == outDistances.getSize(0));
    FAISS_ASSERT(productDistances.getSize(0) == outIndices.getSize(0));
    FAISS_ASSERT(outDistances.getSize(1) == k);
    FAISS_ASSERT(outIndices.getSize(1) == k);
    FAISS_ASSERT(k <= 1024);

    if (k == 1) {
        constexpr int kThreadsPerBlock = 256;
        constexpr int kRowsPerBlock = 8;

        auto block = dim3(kThreadsPerBlock);
        auto grid = dim3(faiss::gpu::utils::divUp(outDistances.getSize(0), kRowsPerBlock));

        topkSelectMin1<T, kRowsPerBlock, kThreadsPerBlock>
            <<<grid, block, 0, stream>>>(productDistances, outDistances, outIndices);
    } else {
        constexpr int kThreadsPerBlock = 128;

        auto block = dim3(kThreadsPerBlock);
        auto grid = dim3(outDistances.getSize(0));

#define RUN_TOPK_SELECT_MIN(NUM_WARP_Q, NUM_THREAD_Q)                         \
    do {                                                                      \
        topkSelectMinK<T, NUM_WARP_Q, NUM_THREAD_Q, kThreadsPerBlock>         \
            <<<grid, block, 0, stream>>>(productDistances,                    \
                                         outDistances, outIndices,            \
                                         k, faiss::gpu::Limits<T>::getMax()); \
    } while (0)

        if (k <= 32) {
            RUN_TOPK_SELECT_MIN(32, 2);
        } else if (k <= 64) {
            RUN_TOPK_SELECT_MIN(64, 3);
        } else if (k <= 128) {
            RUN_TOPK_SELECT_MIN(128, 3);
        } else if (k <= 256) {
            RUN_TOPK_SELECT_MIN(256, 4);
        } else if (k <= 512) {
            RUN_TOPK_SELECT_MIN(512, 8);
        } else if (k <= 1024) {
            RUN_TOPK_SELECT_MIN(1024, 8);
        } else {
            FAISS_ASSERT(false);
        }
    }

    CUDA_TEST_ERROR();
}

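// Note on the dispatch above: NumWarpQ is the per-warp queue length and must
// cover k, so k is rounded up to the next supported queue size; NumThreadQ is
// the per-thread queue depth, grown with k to amortize warp-wide merges. The
// (32,2) ... (1024,8) pairs appear to follow faiss's own k-selection dispatch
// tables.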
////////////////////////////////////////////////////////////
// select kernel for k == 1
template<typename T, int kRowsPerBlock, int kBlockSize>
__global__ void topkSelectMax1(Tensor<T, 2, true> productDistances,
                               Tensor<T, 2, true> outDistances,
                               Tensor<int64_t, 2, true> outIndices) {
    // Each block handles kRowsPerBlock rows of the distances (results)
    Pair<T, int64_t> threadMax[kRowsPerBlock];
    __shared__ Pair<T, int64_t> blockMax[kRowsPerBlock * (kBlockSize / kWarpSize)];

    T distance[kRowsPerBlock];

#pragma unroll
    for (int i = 0; i < kRowsPerBlock; ++i) {
        threadMax[i].k = faiss::gpu::Limits<T>::getMin();
        threadMax[i].v = -1;
    }

    // blockIdx.x: which chunk of rows we are responsible for updating
    int rowStart = blockIdx.x * kRowsPerBlock;

    // FIXME: if we have exact multiples, don't need this
    bool endRow = (blockIdx.x == gridDim.x - 1);

    if (endRow) {
        if (productDistances.getSize(0) % kRowsPerBlock == 0) {
            endRow = false;
        }
    }

    if (endRow) {
        for (int row = rowStart; row < productDistances.getSize(0); ++row) {
            for (int col = threadIdx.x; col < productDistances.getSize(1);
                 col += blockDim.x) {
                distance[0] = productDistances[row][col];

                if (faiss::gpu::Math<T>::gt(distance[0], threadMax[0].k)) {
                    threadMax[0].k = distance[0];
                    threadMax[0].v = col;
                }
            }

            // Reduce within the block
            threadMax[0] =
                faiss::gpu::blockReduceAll<Pair<T, int64_t>, faiss::gpu::Max<Pair<T, int64_t> >, false, false>(
                    threadMax[0], faiss::gpu::Max<Pair<T, int64_t> >(), blockMax);

            if (threadIdx.x == 0) {
                outDistances[row][0] = threadMax[0].k;
                outIndices[row][0] = threadMax[0].v;
            }

            // so we can use the shared memory again
            __syncthreads();

            threadMax[0].k = faiss::gpu::Limits<T>::getMin();
            threadMax[0].v = -1;
        }
    } else {
        for (int col = threadIdx.x; col < productDistances.getSize(1);
             col += blockDim.x) {

#pragma unroll
            for (int row = 0; row < kRowsPerBlock; ++row) {
                distance[row] = productDistances[rowStart + row][col];
            }

#pragma unroll
            for (int row = 0; row < kRowsPerBlock; ++row) {
                if (faiss::gpu::Math<T>::gt(distance[row], threadMax[row].k)) {
                    threadMax[row].k = distance[row];
                    threadMax[row].v = col;
                }
            }
        }

        // Reduce within the block
        faiss::gpu::blockReduceAll<kRowsPerBlock, Pair<T, int64_t>, faiss::gpu::Max<Pair<T, int64_t> >, false, false>(
            threadMax, faiss::gpu::Max<Pair<T, int64_t> >(), blockMax);

        if (threadIdx.x == 0) {
#pragma unroll
            for (int row = 0; row < kRowsPerBlock; ++row) {
                outDistances[rowStart + row][0] = threadMax[row].k;
                outIndices[rowStart + row][0] = threadMax[row].v;
            }
        }
    }
}

// select kernel for k > 1 (adapted from faiss's L2 select; there is no
// ||c||^2 re-use here since this is plain selection, not L2 distance)
template<typename T, int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void topkSelectMaxK(Tensor<T, 2, true> productDistances,
                               Tensor<T, 2, true> outDistances,
                               Tensor<int64_t, 2, true> outIndices,
                               int k, T initK) {
    // Each block handles a single row of the distances (results)
    constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

    __shared__ T smemK[kNumWarps * NumWarpQ];
    __shared__ int64_t smemV[kNumWarps * NumWarpQ];

    faiss::gpu::BlockSelect<T, int64_t, true, faiss::gpu::Comparator<T>,
                            NumWarpQ, NumThreadQ, ThreadsPerBlock>
        heap(initK, -1, smemK, smemV, k);

    int row = blockIdx.x;

    // Whole warps must participate in the selection
    int limit = faiss::gpu::utils::roundDown(productDistances.getSize(1), kWarpSize);
    int i = threadIdx.x;

    for (; i < limit; i += blockDim.x) {
        T v = productDistances[row][i];
        heap.add(v, i);
    }

    if (i < productDistances.getSize(1)) {
        T v = productDistances[row][i];
        heap.addThreadQ(v, i);
    }

    heap.reduce();
    for (int i = threadIdx.x; i < k; i += blockDim.x) {
        outDistances[row][i] = smemK[i];
        outIndices[row][i] = smemV[i];
    }
}

// FIXME: no TVec specialization
template<typename T>
void runTopKSelectMax(Tensor<T, 2, true> &productDistances,
                      Tensor<T, 2, true> &outDistances,
                      Tensor<int64_t, 2, true> &outIndices,
                      int k,
                      cudaStream_t stream) {
    FAISS_ASSERT(productDistances.getSize(0) == outDistances.getSize(0));
    FAISS_ASSERT(productDistances.getSize(0) == outIndices.getSize(0));
    FAISS_ASSERT(outDistances.getSize(1) == k);
    FAISS_ASSERT(outIndices.getSize(1) == k);
    FAISS_ASSERT(k <= 1024);

    if (k == 1) {
        constexpr int kThreadsPerBlock = 256;
        constexpr int kRowsPerBlock = 8;

        auto block = dim3(kThreadsPerBlock);
        auto grid = dim3(faiss::gpu::utils::divUp(outDistances.getSize(0), kRowsPerBlock));

        topkSelectMax1<T, kRowsPerBlock, kThreadsPerBlock>
            <<<grid, block, 0, stream>>>(productDistances, outDistances, outIndices);
    } else {
        constexpr int kThreadsPerBlock = 128;

        auto block = dim3(kThreadsPerBlock);
        auto grid = dim3(outDistances.getSize(0));

#define RUN_TOPK_SELECT_MAX(NUM_WARP_Q, NUM_THREAD_Q)                         \
    do {                                                                      \
        topkSelectMaxK<T, NUM_WARP_Q, NUM_THREAD_Q, kThreadsPerBlock>         \
            <<<grid, block, 0, stream>>>(productDistances,                    \
                                         outDistances, outIndices,            \
                                         k, faiss::gpu::Limits<T>::getMin()); \
    } while (0)

        if (k <= 32) {
            RUN_TOPK_SELECT_MAX(32, 2);
        } else if (k <= 64) {
            RUN_TOPK_SELECT_MAX(64, 3);
        } else if (k <= 128) {
            RUN_TOPK_SELECT_MAX(128, 3);
        } else if (k <= 256) {
            RUN_TOPK_SELECT_MAX(256, 4);
        } else if (k <= 512) {
            RUN_TOPK_SELECT_MAX(512, 8);
        } else if (k <= 1024) {
            RUN_TOPK_SELECT_MAX(1024, 8);
        } else {
            FAISS_ASSERT(false);
        }
    }

    CUDA_TEST_ERROR();
}
//////////////////////////////////////////////////////////////

template<typename T>
void runTopKSelect(Tensor<T, 2, true> &productDistances,
                   Tensor<T, 2, true> &outDistances,
                   Tensor<int64_t, 2, true> &outIndices,
                   bool dir,
                   int k,
                   cudaStream_t stream) {
    if (dir) {
        runTopKSelectMax<T>(productDistances, outDistances, outIndices, k, stream);
    } else {
        runTopKSelectMin<T>(productDistances, outDistances, outIndices, k, stream);
    }
}

template<typename T>
void TopK(T *input,
          int length,
          int k,
          T *output,
          int64_t *idx,
          // Ordering order_flag,
          cudaStream_t stream) {

    // bool dir = (order_flag == Ordering::kAscending ? false : true);
    bool dir = false;  // always select minima until order_flag is wired up

    Tensor<T, 2, true> t_input(input, {1, length});
    Tensor<T, 2, true> t_output(output, {1, k});
    Tensor<int64_t, 2, true> t_idx(idx, {1, k});

    runTopKSelect<T>(t_input, t_output, t_idx, dir, k, stream);
}

//INSTANTIATION_TOPK_2(bool);
//INSTANTIATION_TOPK_2(int8_t);
//INSTANTIATION_TOPK_2(int16_t);
INSTANTIATION_TOPK_2(int32_t);
//INSTANTIATION_TOPK_2(int64_t);
INSTANTIATION_TOPK_2(float);
//INSTANTIATION_TOPK_2(double);
//INSTANTIATION_TOPK(TimeInterval);
//INSTANTIATION_TOPK(Float128);
//INSTANTIATION_TOPK(char);

} // namespace gpu

void TopK(float *host_input,
          int length,
          int k,
          float *output,
          int64_t *indices) {
    float *device_input, *device_output;
    int64_t *ids;

    cudaMalloc((void **) &device_input, sizeof(float) * length);
    cudaMalloc((void **) &device_output, sizeof(float) * k);
    cudaMalloc((void **) &ids, sizeof(int64_t) * k);

    cudaMemcpy(device_input, host_input, sizeof(float) * length, cudaMemcpyHostToDevice);

    gpu::TopK<float>(device_input, length, k, device_output, ids, nullptr);

    cudaMemcpy(output, device_output, sizeof(float) * k, cudaMemcpyDeviceToHost);
    cudaMemcpy(indices, ids, sizeof(int64_t) * k, cudaMemcpyDeviceToHost);

    cudaFree(device_input);
    cudaFree(device_output);
    cudaFree(ids);
}

} // namespace engine
} // namespace vecwise
} // namespace zilliz
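For callers that already hold device pointers, the templated entry point can be driven directly. A hedged sketch with an explicit stream (names as in the files above; stream creation and error checks are the caller's responsibility):

#include <cuda_runtime.h>
#include "wrapper/Topk.h"

void example(float *d_input, int length, int k,
             float *d_output, int64_t *d_indices) {
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Device-pointer overload: no host copies, asynchronous on `stream`.
    zilliz::vecwise::engine::gpu::TopK<float>(d_input, length, k,
                                              d_output, d_indices, stream);

    cudaStreamSynchronize(stream);  // results are valid only after sync
    cudaStreamDestroy(stream);
}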
cpp/src/wrapper/Topk.h (new file, 61 lines)
@@ -0,0 +1,61 @@
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////

#pragma once

#include <cstdint>

#include <cuda.h>
#include <cuda_runtime.h>


namespace zilliz {
namespace vecwise {
namespace engine {
namespace gpu {

template<typename T>
void
TopK(T *input,
     int length,
     int k,
     T *output,
     int64_t *indices,
     // Ordering order_flag,
     cudaStream_t stream = nullptr);


#define INSTANTIATION_TOPK_2(T) \
    template void               \
    TopK<T>(T *input,           \
            int length,         \
            int k,              \
            T *output,          \
            int64_t *indices,   \
            cudaStream_t stream)
// Ordering order_flag, \
// cudaStream_t stream)

//extern INSTANTIATION_TOPK_2(int8_t);
//extern INSTANTIATION_TOPK_2(int16_t);
extern INSTANTIATION_TOPK_2(int32_t);
//extern INSTANTIATION_TOPK_2(int64_t);
extern INSTANTIATION_TOPK_2(float);
//extern INSTANTIATION_TOPK_2(double);
//extern INSTANTIATION_TOPK(TimeInterval);
//extern INSTANTIATION_TOPK(Float128);

} // namespace gpu

// User Interface.
void TopK(float *input,
          int length,
          int k,
          float *output,
          int64_t *indices);


} // namespace engine
} // namespace vecwise
} // namespace zilliz
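Since the template body lives in Topk.cu, translation units that include only this header rely on the explicit instantiations declared above. For reference, `extern INSTANTIATION_TOPK_2(float);` expands (modulo whitespace) to:

extern template void
TopK<float>(float *input, int length, int k,
            float *output, int64_t *indices, cudaStream_t stream);

which suppresses implicit instantiation and defers to the definition compiled in Topk.cu.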
@@ -24,3 +24,10 @@ set(faiss_libs
     cublas
     )
 target_link_libraries(wrapper_test ${unittest_libs} ${faiss_libs})
+
+set(topk_test_src
+    topk_test.cpp
+    ${CMAKE_SOURCE_DIR}/src/wrapper/Topk.cu)
+
+cuda_add_executable(topk_test ${topk_test_src})
+target_link_libraries(topk_test ${unittest_libs} ${faiss_libs})
cpp/unittest/faiss_wrapper/topk_test.cpp (new file, 89 lines)
@@ -0,0 +1,89 @@
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////

#include <gtest/gtest.h>

#include <algorithm>
#include <cmath>
#include <random>
#include <vector>

#include <cuda_runtime.h>

#include "wrapper/Topk.h"


using namespace zilliz::vecwise::engine;

constexpr float threshold = 0.00001;

template<typename T>
void TopK_check(T *data,
                int length,
                int k,
                T *result) {

    // Reference: sort ascending on the host; the first k entries are the
    // expected minima.
    std::vector<T> arr(data, data + length);
    std::sort(arr.begin(), arr.end(), std::less<T>());

    for (int i = 0; i < k; ++i) {
        ASSERT_TRUE(fabs(arr[i] - result[i]) < threshold);
    }
}

TEST(wrapper_topk, Wrapper_Test) {
    int length = 100000;
    int k = 1000;

    float *host_input, *host_output;
    int64_t *ids;

    host_input = (float *) malloc(length * sizeof(float));
    host_output = (float *) malloc(k * sizeof(float));
    ids = (int64_t *) malloc(k * sizeof(int64_t));

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<> dis(-1.0, 1.0);
    for (int i = 0; i < length; ++i) {
        host_input[i] = 1.0 * dis(gen);
    }

    TopK(host_input, length, k, host_output, ids);
    TopK_check(host_input, length, k, host_output);

    free(host_input);
    free(host_output);
    free(ids);
}

template<typename T>
void TopK_Test(T factor) {
    int length = 1000000; // data length
    int k = 100;

    T *data, *out;
    int64_t *idx;
    cudaMallocManaged((void **) &data, sizeof(T) * length);
    cudaMallocManaged((void **) &out, sizeof(T) * k);
    cudaMallocManaged((void **) &idx, sizeof(int64_t) * k);

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<> dis(-1.0, 1.0);

    for (int i = 0; i < length; i++) {
        data[i] = factor * dis(gen);
    }

    cudaMemAdvise(data, sizeof(T) * length, cudaMemAdviseSetReadMostly, 0);
    cudaMemPrefetchAsync(data, sizeof(T) * length, 0);

    gpu::TopK<T>(data, length, k, out, idx, nullptr);
    cudaDeviceSynchronize();  // kernel runs on the default stream; sync before
                              // touching the managed buffers from the host
    TopK_check<T>(data, length, k, out);

    // order_flag = Ordering::kDescending;
    // TopK<T>(data, length, k, out, idx, nullptr);
    // TopK_check<T>(data, length, k, out);

    cudaFree(data);
    cudaFree(out);
    cudaFree(idx);
}

TEST(topk_test, Wrapper_Test) {
    TopK_Test<float>(1.0);
}
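One gap worth noting: TopK_check validates only the selected distances, never the returned indices. A hypothetical extension (not part of this commit) that also checks each returned index points at a matching input value:

// Hypothetical helper, not in the commit: verify that every returned index
// is in range and refers to an input element equal (within threshold) to the
// paired distance.
template<typename T>
void TopK_check_indices(T *data, int length, int k, T *result, int64_t *indices) {
    for (int i = 0; i < k; ++i) {
        ASSERT_TRUE(indices[i] >= 0 && indices[i] < length);
        ASSERT_TRUE(fabs(data[indices[i]] - result[i]) < threshold);
    }
}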