Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 70 additions & 3 deletions cpp/src/Spaces/Euclidean.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,51 @@ static float L2SqrSIMD16Ext(const float *pVect1, const float *pVect2,
_mm_store_ps(TmpRes, sum);
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
}

#elif defined(USE_NEON)

static float L2SqrSIMD16Ext(const float *pVect1, const float *pVect2,
const size_t qty) {
size_t qty16 = qty >> 4;
const float *pEnd1 = pVect1 + (qty16 << 4);

float32x4_t sum0 = vdupq_n_f32(0);
float32x4_t sum1 = vdupq_n_f32(0);
float32x4_t sum2 = vdupq_n_f32(0);
float32x4_t sum3 = vdupq_n_f32(0);

while (pVect1 < pEnd1) {
float32x4_t v1_0 = vld1q_f32(pVect1);
float32x4_t v2_0 = vld1q_f32(pVect2);
float32x4_t diff0 = vsubq_f32(v1_0, v2_0);
sum0 = vfmaq_f32(sum0, diff0, diff0);

float32x4_t v1_1 = vld1q_f32(pVect1 + 4);
float32x4_t v2_1 = vld1q_f32(pVect2 + 4);
float32x4_t diff1 = vsubq_f32(v1_1, v2_1);
sum1 = vfmaq_f32(sum1, diff1, diff1);

float32x4_t v1_2 = vld1q_f32(pVect1 + 8);
float32x4_t v2_2 = vld1q_f32(pVect2 + 8);
float32x4_t diff2 = vsubq_f32(v1_2, v2_2);
sum2 = vfmaq_f32(sum2, diff2, diff2);

float32x4_t v1_3 = vld1q_f32(pVect1 + 12);
float32x4_t v2_3 = vld1q_f32(pVect2 + 12);
float32x4_t diff3 = vsubq_f32(v1_3, v2_3);
sum3 = vfmaq_f32(sum3, diff3, diff3);

pVect1 += 16;
pVect2 += 16;
}

return vaddvq_f32(vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)));
}

#endif

#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) || \
defined(USE_NEON)
static float L2SqrSIMD16ExtResiduals(const float *pVect1, const float *pVect2,
const size_t qty) {
size_t qty16 = qty >> 4 << 4;
Expand All @@ -189,7 +231,7 @@ static float L2SqrSIMD16ExtResiduals(const float *pVect1, const float *pVect2,
}
#endif

#ifdef USE_SSE
#if defined(USE_SSE)
static float L2SqrSIMD4Ext(const float *pVect1, const float *pVect2,
const size_t qty) {
float PORTABLE_ALIGN32 TmpRes[8];
Expand All @@ -212,6 +254,30 @@ static float L2SqrSIMD4Ext(const float *pVect1, const float *pVect2,
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
}

#elif defined(USE_NEON)

static float L2SqrSIMD4Ext(const float *pVect1, const float *pVect2,
const size_t qty) {
size_t qty4 = qty >> 2;
const float *pEnd1 = pVect1 + (qty4 << 2);

float32x4_t sum = vdupq_n_f32(0);

while (pVect1 < pEnd1) {
float32x4_t v1 = vld1q_f32(pVect1);
pVect1 += 4;
float32x4_t v2 = vld1q_f32(pVect2);
pVect2 += 4;
float32x4_t diff = vsubq_f32(v1, v2);
sum = vfmaq_f32(sum, diff, diff);
}

return vaddvq_f32(sum);
}

#endif

#if defined(USE_SSE) || defined(USE_NEON)
static float L2SqrSIMD4ExtResiduals(const float *pVect1, const float *pVect2,
const size_t qty) {
size_t qty4 = qty >> 2 << 2;
Expand Down Expand Up @@ -276,7 +342,8 @@ template <>
EuclideanSpace<float, float>::EuclideanSpace(size_t dim)
: data_size_(dim * sizeof(float)), dim_(dim) {
fstdistfunc_ = L2Sqr<float, float>;
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) || \
defined(USE_NEON)
if (dim % 16 == 0)
fstdistfunc_ = L2SqrSIMD16Ext;
else if (dim % 4 == 0)
Expand Down
94 changes: 92 additions & 2 deletions cpp/src/Spaces/InnerProduct.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,56 @@ static float InnerProductSIMD4Ext(const float *pVect1, const float *pVect2,
return 1.0f - sum;
}

#elif defined(USE_NEON)

static float InnerProductSIMD4Ext(const float *pVect1, const float *pVect2,
const size_t qty) {
size_t qty16 = qty / 16;
size_t qty4 = qty / 4;

const float *pEnd1 = pVect1 + 16 * qty16;
const float *pEnd2 = pVect1 + 4 * qty4;

float32x4_t sum0 = vdupq_n_f32(0);
float32x4_t sum1 = vdupq_n_f32(0);
float32x4_t sum2 = vdupq_n_f32(0);
float32x4_t sum3 = vdupq_n_f32(0);

while (pVect1 < pEnd1) {
float32x4_t v1_0 = vld1q_f32(pVect1);
float32x4_t v2_0 = vld1q_f32(pVect2);
sum0 = vfmaq_f32(sum0, v1_0, v2_0);

float32x4_t v1_1 = vld1q_f32(pVect1 + 4);
float32x4_t v2_1 = vld1q_f32(pVect2 + 4);
sum1 = vfmaq_f32(sum1, v1_1, v2_1);

float32x4_t v1_2 = vld1q_f32(pVect1 + 8);
float32x4_t v2_2 = vld1q_f32(pVect2 + 8);
sum2 = vfmaq_f32(sum2, v1_2, v2_2);

float32x4_t v1_3 = vld1q_f32(pVect1 + 12);
float32x4_t v2_3 = vld1q_f32(pVect2 + 12);
sum3 = vfmaq_f32(sum3, v1_3, v2_3);

pVect1 += 16;
pVect2 += 16;
}

float32x4_t sum_prod =
vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3));

while (pVect1 < pEnd2) {
float32x4_t v1 = vld1q_f32(pVect1);
pVect1 += 4;
float32x4_t v2 = vld1q_f32(pVect2);
pVect2 += 4;
sum_prod = vfmaq_f32(sum_prod, v1, v2);
}

return 1.0f - vaddvq_f32(sum_prod);
}

#endif

#if defined(USE_AVX512)
Expand Down Expand Up @@ -294,9 +344,48 @@ static float InnerProductSIMD16Ext(const float *pVect1, const float *pVect2,
return 1.0f - sum;
}

#elif defined(USE_NEON)

static float InnerProductSIMD16Ext(const float *pVect1, const float *pVect2,
const size_t qty) {
size_t qty16 = qty / 16;
const float *pEnd1 = pVect1 + 16 * qty16;

float32x4_t sum0 = vdupq_n_f32(0);
float32x4_t sum1 = vdupq_n_f32(0);
float32x4_t sum2 = vdupq_n_f32(0);
float32x4_t sum3 = vdupq_n_f32(0);

while (pVect1 < pEnd1) {
float32x4_t v1_0 = vld1q_f32(pVect1);
float32x4_t v2_0 = vld1q_f32(pVect2);
sum0 = vfmaq_f32(sum0, v1_0, v2_0);

float32x4_t v1_1 = vld1q_f32(pVect1 + 4);
float32x4_t v2_1 = vld1q_f32(pVect2 + 4);
sum1 = vfmaq_f32(sum1, v1_1, v2_1);

float32x4_t v1_2 = vld1q_f32(pVect1 + 8);
float32x4_t v2_2 = vld1q_f32(pVect2 + 8);
sum2 = vfmaq_f32(sum2, v1_2, v2_2);

float32x4_t v1_3 = vld1q_f32(pVect1 + 12);
float32x4_t v2_3 = vld1q_f32(pVect2 + 12);
sum3 = vfmaq_f32(sum3, v1_3, v2_3);

pVect1 += 16;
pVect2 += 16;
}

float sum =
vaddvq_f32(vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)));
return 1.0f - sum;
}

#endif

#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) || \
defined(USE_NEON)
static float InnerProductSIMD16ExtResiduals(const float *pVect1,
const float *pVect2,
const size_t qty) {
Expand Down Expand Up @@ -374,7 +463,8 @@ template <>
InnerProductSpace<float, float>::InnerProductSpace(size_t dim)
: data_size_(dim * sizeof(float)), dim_(dim) {
fstdistfunc_ = InnerProduct<float, float>;
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) || \
defined(USE_NEON)
if (dim % 16 == 0)
fstdistfunc_ = InnerProductSIMD16Ext;
else if (dim % 4 == 0)
Expand Down
7 changes: 7 additions & 0 deletions cpp/src/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
#endif
#endif
#endif
#ifdef __ARM_NEON
#define USE_NEON
#endif
#endif

#if defined(USE_AVX) || defined(USE_SSE)
Expand All @@ -55,6 +58,10 @@
#endif
#endif

#if defined(USE_NEON)
#include <arm_neon.h>
#endif

#include "StreamUtils.h"
#include "visited_list_pool.h"
#include <functional>
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["scikit-build-core >=0.4.3", "nanobind >=1.3.2"]
requires = ["scikit-build-core >=0.4.3", "nanobind >=1.3.2,<=2.9.2"]
build-backend = "scikit_build_core.build"

[project]
Expand Down
Loading