Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 35 additions & 20 deletions include/flatnav/distances/IPDistanceDispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
namespace flatnav::distances {

template <typename T>
static float defaultInnerProduct(const T* x, const T* y, const size_t& dimension) {
static float default_inner_product(const T* x, const T* y, const size_t& dimension) {
float inner_product = 0;
for (size_t i = 0; i < dimension; i++) {
inner_product += x[i] * y[i];
Expand All @@ -18,7 +18,7 @@ static float defaultInnerProduct(const T* x, const T* y, const size_t& dimension
template <typename T>
struct InnerProductImpl {
static float computeDistance(const T* x, const T* y, const size_t& dimension) {
return defaultInnerProduct<T>(x, y, dimension);
return default_inner_product<T>(x, y, dimension);
}
};

Expand All @@ -28,67 +28,82 @@ struct InnerProductImpl<float> {
#if defined(USE_AVX512)
if (platformSupportsAvx512()) {
if (dimension % 16 == 0) {
return util::computeIP_Avx512(x, y, dimension);
return util::compute_ip_avx512(x, y, dimension);
}
if (dimension % 4 == 0) {
#if defined(USE_AVX)
return util::computeIP_Avx_4aligned(x, y, dimension);
return util::compute_ip_avx_4aligned(x, y, dimension);
#else
return util::computeIP_Sse4Aligned(x, y, dimension);
return util::compute_ip_sse_4aligned(x, y, dimension);
#endif
} else if (dimension > 16) {
return util::computeIP_SseWithResidual_16(x, y, dimension);
return util::compute_ip_sse_residual_16(x, y, dimension);
} else if (dimension > 4) {
return util::computeIP_SseWithResidual_4(x, y, dimension);
return util::compute_ip_sse_residual_4(x, y, dimension);
}
}
#endif

#if defined(USE_AVX)
if (platformSupportsAvx()) {
if (dimension % 16 == 0) {
return util::computeIP_Avx(x, y, dimension);
return util::compute_ip_avx(x, y, dimension);
}
if (dimension % 4 == 0) {
return util::computeIP_Avx_4aligned(x, y, dimension);
return util::compute_ip_avx_4aligned(x, y, dimension);
} else if (dimension > 16) {
return util::computeIP_SseWithResidual_16(x, y, dimension);
return util::compute_ip_sse_residual_16(x, y, dimension);
} else if (dimension > 4) {
return util::computeIP_SseWithResidual_4(x, y, dimension);
return util::compute_ip_sse_residual_4(x, y, dimension);
}
}
#endif

#if defined(USE_SSE)
if (dimension % 16 == 0) {
return util::computeIP_Sse(x, y, dimension);
return util::compute_ip_sse(x, y, dimension);
}
if (dimension % 4 == 0) {
return util::computeIP_Sse_4aligned(x, y, dimension);
return util::compute_ip_sse_4aligned(x, y, dimension);
} else if (dimension > 16) {
return util::computeIP_SseWithResidual_16(x, y, dimension);
return util::compute_ip_sse_residual_16(x, y, dimension);
} else if (dimension > 4) {
return util::computeIP_SseWithResidual_4(x, y, dimension);
return util::compute_ip_sse_residual_4(x, y, dimension);
}

#endif
return defaultInnerProduct<float>(x, y, dimension);
return default_inner_product<float>(x, y, dimension);
}
};

// TODO: Include SIMD optimized implementations for int8_t.
template <>
struct InnerProductImpl<int8_t> {
static float computeDistance(const int8_t* x, const int8_t* y, const size_t& dimension) {
return defaultInnerProduct<int8_t>(x, y, dimension);
#if defined(USE_AVX512)
if (platformSupportsAvx512()) {
return util::compute_ip_avx512_int8(x, y, dimension);
}
#endif

#if defined(USE_AVX)
if (platformSupportsAvx()) {
return util::compute_ip_avx2_int8(x, y, dimension);
}
#endif

#if defined(USE_SSE4_1)
return util::compute_ip_sse_int8(x, y, dimension);
#endif

return default_inner_product<int8_t>(x, y, dimension);
}
};

// TODO: Include SIMD optimized implementations for uint8_t.
template <>
struct InnerProductImpl<uint8_t> {
static float computeDistance(const uint8_t* x, const uint8_t* y, const size_t& dimension) {
return defaultInnerProduct<uint8_t>(x, y, dimension);
return default_inner_product<uint8_t>(x, y, dimension);
}
};

Expand All @@ -99,4 +114,4 @@ struct IPDistanceDispatcher {
}
};

} // namespace flatnav::distances
} // namespace flatnav::distances
62 changes: 34 additions & 28 deletions include/flatnav/distances/L2DistanceDispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
namespace flatnav::distances {

template <typename T>
static float defaultSquaredL2(const T* x, const T* y, const size_t& dimension) {
static float default_squared_l2(const T* x, const T* y, const size_t& dimension) {
float squared_distance = 0;
for (size_t i = 0; i < dimension; i++) {
float difference = x[i] - y[i];
Expand All @@ -31,7 +31,7 @@ struct SquaredL2Impl {
* @return The squared L2 distance between the two arrays.
*/
static float computeDistance(const T* x, const T* y, const size_t& dimension) {
return defaultSquaredL2<T>(x, y, dimension);
return default_squared_l2<T>(x, y, dimension);
}
};

Expand All @@ -42,64 +42,70 @@ struct SquaredL2Impl<float> {
#if defined(USE_AVX512)
if (platformSupportsAvx512()) {
if (dimension % 16 == 0) {
return util::computeL2_Avx512(x, y, dimension);
return util::compute_l2_avx512(x, y, dimension);
}
if (dimension % 4 == 0) {
return util::computeL2_Sse4Aligned(x, y, dimension);
return util::compute_l2_sse_4aligned(x, y, dimension);
} else if (dimension > 16) {
return util::computeL2_SseWithResidual_16(x, y, dimension);
return util::compute_l2_sse_residual_16(x, y, dimension);
} else if (dimension > 4) {
return util::computeL2_SseWithResidual_4(x, y, dimension);
return util::compute_l2_sse_residual_4(x, y, dimension);
}
}
#endif

#if defined(USE_AVX)
if (platformSupportsAvx()) {
if (dimension % 16 == 0) {
return util::computeL2_Avx2(x, y, dimension);
return util::compute_l2_avx2(x, y, dimension);
}
if (dimension % 4 == 0) {
return util::computeL2_Sse4Aligned(x, y, dimension);
return util::compute_l2_sse_4aligned(x, y, dimension);
} else if (dimension > 16) {
return util::computeL2_SseWithResidual_16(x, y, dimension);
return util::compute_l2_sse_residual_16(x, y, dimension);
} else if (dimension > 4) {
return util::computeL2_SseWithResidual_4(x, y, dimension);
return util::compute_l2_sse_residual_4(x, y, dimension);
}
}
#endif

#if defined(USE_SSE)
if (dimension % 16 == 0) {
return util::computeL2_Sse(x, y, dimension);
return util::compute_l2_sse(x, y, dimension);
}
if (dimension % 4 == 0) {
return util::computeL2_Sse4Aligned(x, y, dimension);
return util::compute_l2_sse_4aligned(x, y, dimension);
} else if (dimension > 16) {
return util::computeL2_SseWithResidual_16(x, y, dimension);
return util::compute_l2_sse_residual_16(x, y, dimension);
} else if (dimension > 4) {
return util::computeL2_SseWithResidual_4(x, y, dimension);
return util::compute_l2_sse_residual_4(x, y, dimension);
}
#else
return defaultSquaredL2<float>(x, y, dimension);
return default_squared_l2<float>(x, y, dimension);
#endif
}
};

template <>
struct SquaredL2Impl<int8_t> {
static float computeDistance(const int8_t* x, const int8_t* y, const size_t& dimension) {
// #if defined(USE_AVX512BW) && defined(USE_AVX512VNNI)
// if (platformSupportsAvx512()) {
// return flatnav::util::computeL2_Avx512_int8(x, y, dimension);
// }
// #endif
#if defined(USE_SSE_4_1)
// This requires some advanced SSE4.1 instructions, such as _mm_cvtepi8_epi16
// Reference: https://doc.rust-lang.org/beta/core/arch/x86_64/fn._mm_cvtepi8_epi16.html
return flatnav::util::computeL2_Sse_int8(x, y, dimension);
#if defined(USE_AVX512)
if (platformSupportsAvx512()) {
return util::compute_l2_avx512_int8(x, y, dimension);
}
#endif
return defaultSquaredL2<int8_t>(x, y, dimension);

#if defined(USE_AVX)
if (platformSupportsAvx()) {
return util::compute_l2_avx2_int8(x, y, dimension);
}
#endif

#if defined(USE_SSE4_1)
return util::compute_l2_sse_int8(x, y, dimension);
#endif

return default_squared_l2<int8_t>(x, y, dimension);
}
};

Expand All @@ -109,12 +115,12 @@ struct SquaredL2Impl<uint8_t> {
#if defined(USE_AVX512)
if (platformSupportsAvx512()) {
if (dimension % 64 == 0) {
return util::computeL2_Avx512_Uint8(x, y, dimension);
return util::compute_l2_avx512_uint8(x, y, dimension);
}
}
#endif

return defaultSquaredL2<uint8_t>(x, y, dimension);
return default_squared_l2<uint8_t>(x, y, dimension);
}
};

Expand All @@ -125,4 +131,4 @@ struct L2DistanceDispatcher {
}
};

} // namespace flatnav::distances
} // namespace flatnav::distances
17 changes: 13 additions & 4 deletions include/flatnav/quantization/ScalarQuantizedDistance.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,21 @@ class ScalarQuantizedDistance
float distanceImpl(const void* x, const void* y,
bool asymmetric = false) const {
if (asymmetric) {
// x is float query, y is stored int8
// x is float query, y is stored int8.
// Cache the quantized query by pointer: during beam search the same
// float query is compared against thousands of stored int8 vectors,
// so we only need to quantize once per unique query pointer.
thread_local std::vector<int8_t> query_buf;
if (query_buf.size() != _dimension) {
query_buf.resize(_dimension);
thread_local const float* cached_ptr = nullptr;

const float* query_ptr = static_cast<const float*>(x);
if (query_ptr != cached_ptr) {
if (query_buf.size() != _dimension) {
query_buf.resize(_dimension);
}
quantize(query_ptr, query_buf.data());
cached_ptr = query_ptr;
}
quantize(static_cast<const float*>(x), query_buf.data());
return computeInt8Distance(query_buf.data(),
static_cast<const int8_t*>(y));
}
Expand Down
Loading
Loading