Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/VecSim/vec_sim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,18 @@ extern "C" void VecSim_Normalize(void *blob, size_t dim, VecSimType type) {
}
}

extern "C" size_t VecSimParams_GetQueryBlobSize(VecSimType type, size_t dim, VecSimMetric metric) {
// Assert all supported types are covered
assert(type == VecSimType_FLOAT32 || type == VecSimType_FLOAT64 ||
type == VecSimType_BFLOAT16 || type == VecSimType_FLOAT16 || type == VecSimType_INT8 ||
type == VecSimType_UINT8);
size_t blobSize = VecSimType_sizeof(type) * dim;
if (metric == VecSimMetric_Cosine && (type == VecSimType_INT8 || type == VecSimType_UINT8)) {
blobSize += sizeof(float); // For the norm
}
return blobSize;
}

extern "C" size_t VecSimIndex_IndexSize(VecSimIndex *index) { return index->indexSize(); }

extern "C" VecSimResolveCode VecSimIndex_ResolveParams(VecSimIndex *index, VecSimRawParam *rparams,
Expand Down
13 changes: 13 additions & 0 deletions src/VecSim/vec_sim.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,19 @@ double VecSimIndex_GetDistanceFrom_Unsafe(VecSimIndex *index, size_t label, cons
*/
void VecSim_Normalize(void *blob, size_t dim, VecSimType type);

/**
* @brief Returns the required blob size for a query vector that will be normalized.
*
* For INT8/UINT8 vectors with Cosine metric, VecSim_Normalize appends the norm (a float)
* at the end of the blob, so the required size is larger than just dim * sizeof(type).
*
* @param type vector element type.
* @param dim vector dimension.
* @param metric distance metric.
* @return required blob size in bytes.
*/
size_t VecSimParams_GetQueryBlobSize(VecSimType type, size_t dim, VecSimMetric metric);

/**
* @brief Return the number of vectors in the index.
* @param index the index whose size is requested.
Expand Down
26 changes: 25 additions & 1 deletion tests/unit/test_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,11 @@ class CommonTypeMetricTests : public testing::TestWithParam<std::tuple<VecSimTyp
template <typename algo_params>
void test_initial_size_estimation();

virtual void TearDown() { VecSimIndex_Free(index); }
virtual void TearDown() {
if (index) {
VecSimIndex_Free(index);
}
}

VecSimIndex *index;
};
Expand Down Expand Up @@ -883,6 +887,26 @@ TEST_P(CommonTypeMetricTests, TestInitialSizeEstimationHNSW) {
this->test_initial_size_estimation<HNSWParams>();
}

TEST_P(CommonTypeMetricTests, TestGetQueryBlobSize) {
// We don't need to create an index for this test, set to nullptr to avoid cleanup issues
this->index = nullptr;

size_t dim = 4;
VecSimType type = std::get<0>(GetParam());
VecSimMetric metric = std::get<1>(GetParam());

// Call the API function
size_t actual = VecSimParams_GetQueryBlobSize(type, dim, metric);

// Calculate expected blob size
size_t expected = dim * VecSimType_sizeof(type);
if (metric == VecSimMetric_Cosine && (type == VecSimType_INT8 || type == VecSimType_UINT8)) {
expected += sizeof(float); // For the norm
}

ASSERT_EQ(actual, expected);
}

class CommonTypeMetricTieredTests : public CommonTypeMetricTests {
protected:
virtual void TearDown() override {}
Expand Down
Loading