Skip to content

Commit 3fa3ddf

Browse files
authored
Improve: ChangeThreads APIs for GoLang (#566)
Closes #564
1 parent e8d550b commit 3fa3ddf

File tree

2 files changed

+99
-69
lines changed

2 files changed

+99
-69
lines changed

golang/lib.go

Lines changed: 75 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func (m Metric) CValue() C.usearch_metric_kind_t {
7575
case Sorensen:
7676
return C.usearch_metric_sorensen_k
7777
}
78-
return C.usearch_metric_l2sq_k
78+
return C.usearch_metric_l2sq_k
7979
}
8080

8181
// Quantization represents the type for different scalar kinds used in quantization.
@@ -158,7 +158,7 @@ func NewIndex(conf IndexConfig) (index *Index, err error) {
158158
options.expansion_add = expansion_add
159159
options.expansion_search = expansion_search
160160
options.multi = multi
161-
options.metric_kind = conf.Metric.CValue()
161+
options.metric_kind = conf.Metric.CValue()
162162

163163
// Map the quantization method
164164
switch conf.Quantization {
@@ -256,6 +256,26 @@ func (index *Index) ChangeExpansionSearch(val uint) error {
256256
return nil
257257
}
258258

259+
// ChangeThreadsAdd sets the threads limit for add
260+
func (index *Index) ChangeThreadsAdd(val uint) error {
261+
var errorMessage *C.char
262+
C.usearch_change_threads_add((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), C.size_t(val), (*C.usearch_error_t)(&errorMessage))
263+
if errorMessage != nil {
264+
return errors.New(C.GoString(errorMessage))
265+
}
266+
return nil
267+
}
268+
269+
// ChangeThreadsSearch sets the threads limit for search
270+
func (index *Index) ChangeThreadsSearch(val uint) error {
271+
var errorMessage *C.char
272+
C.usearch_change_threads_search((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), C.size_t(val), (*C.usearch_error_t)(&errorMessage))
273+
if errorMessage != nil {
274+
return errors.New(C.GoString(errorMessage))
275+
}
276+
return nil
277+
}
278+
259279
// Connectivity returns the connectivity parameter of the index.
260280
func (index *Index) Connectivity() (con uint, err error) {
261281
var errorMessage *C.char
@@ -375,7 +395,7 @@ func (index *Index) Get(key Key, count uint) (vectors []float32, err error) {
375395
panic("Index is uninitialized")
376396
}
377397

378-
vectors = make([]float32, index.config.Dimensions * count)
398+
vectors = make([]float32, index.config.Dimensions*count)
379399
var errorMessage *C.char
380400
found := uint(C.usearch_get((C.usearch_index_t)(unsafe.Pointer(index.opaque_handle)), (C.usearch_key_t)(key), (C.size_t)(count), unsafe.Pointer(&vectors[0]), C.usearch_scalar_f32_k, (*C.usearch_error_t)(&errorMessage)))
381401
if errorMessage != nil {
@@ -432,8 +452,8 @@ func (index *Index) Search(query []float32, limit uint) (keys []Key, distances [
432452

433453
// ExactSearch is a multithreaded exact nearest neighbors search
434454
func ExactSearch(dataset []float32, queries []float32, dataset_size uint, queries_size uint,
435-
dataset_stride uint, queries_stride uint, dims uint, metric Metric,
436-
count uint, threads uint, keys_stride uint, distances_stride uint) (keys []Key, distances []float32, err error) {
455+
dataset_stride uint, queries_stride uint, dims uint, metric Metric,
456+
count uint, threads uint, keys_stride uint, distances_stride uint) (keys []Key, distances []float32, err error) {
437457
if (len(dataset) % int(dims)) != 0 {
438458
return nil, nil, errors.New("Dataset length must be a multiple of the dimensions")
439459
}
@@ -444,9 +464,9 @@ func ExactSearch(dataset []float32, queries []float32, dataset_size uint, querie
444464
keys = make([]Key, count)
445465
distances = make([]float32, count)
446466
var errorMessage *C.char
447-
C.usearch_exact_search(unsafe.Pointer(&dataset[0]), C.size_t(dataset_size), C.size_t(dataset_stride), unsafe.Pointer(&queries[0]), C.size_t(queries_size), C.size_t(queries_stride),
448-
C.usearch_scalar_f32_k, C.size_t(dims), metric.CValue(), C.size_t(count), C.size_t(threads),
449-
(*C.usearch_key_t)(&keys[0]), C.size_t(keys_stride), (*C.usearch_distance_t)(&distances[0]), C.size_t(distances_stride), (*C.usearch_error_t)(&errorMessage))
467+
C.usearch_exact_search(unsafe.Pointer(&dataset[0]), C.size_t(dataset_size), C.size_t(dataset_stride), unsafe.Pointer(&queries[0]), C.size_t(queries_size), C.size_t(queries_stride),
468+
C.usearch_scalar_f32_k, C.size_t(dims), metric.CValue(), C.size_t(count), C.size_t(threads),
469+
(*C.usearch_key_t)(&keys[0]), C.size_t(keys_stride), (*C.usearch_distance_t)(&distances[0]), C.size_t(distances_stride), (*C.usearch_error_t)(&errorMessage))
450470
if errorMessage != nil {
451471
return nil, nil, errors.New(C.GoString(errorMessage))
452472
}
@@ -521,43 +541,43 @@ func MetadataBuffer(buf []byte, buffer_size uint) (c IndexConfig, err error) {
521541

522542
// Map the metric kind
523543
switch options.metric_kind {
524-
case C.usearch_metric_l2sq_k:
525-
c.Metric = L2sq
544+
case C.usearch_metric_l2sq_k:
545+
c.Metric = L2sq
526546
case C.usearch_metric_ip_k:
527-
c.Metric = InnerProduct
528-
case C.usearch_metric_cos_k:
529-
c.Metric = Cosine
530-
case C.usearch_metric_haversine_k:
531-
c.Metric = Haversine
532-
case C.usearch_metric_pearson_k:
533-
c.Metric = Pearson
534-
case C.usearch_metric_hamming_k:
535-
c.Metric = Hamming
536-
case C.usearch_metric_tanimoto_k:
537-
c.Metric = Tanimoto
538-
case C.usearch_metric_sorensen_k:
539-
c.Metric = Sorensen
540-
}
547+
c.Metric = InnerProduct
548+
case C.usearch_metric_cos_k:
549+
c.Metric = Cosine
550+
case C.usearch_metric_haversine_k:
551+
c.Metric = Haversine
552+
case C.usearch_metric_pearson_k:
553+
c.Metric = Pearson
554+
case C.usearch_metric_hamming_k:
555+
c.Metric = Hamming
556+
case C.usearch_metric_tanimoto_k:
557+
c.Metric = Tanimoto
558+
case C.usearch_metric_sorensen_k:
559+
c.Metric = Sorensen
560+
}
541561

542562
// Map the quantization method
543563
switch options.quantization {
544-
case C.usearch_scalar_f16_k:
545-
c.Quantization = F16
546-
case C.usearch_scalar_f32_k:
547-
c.Quantization = F32
564+
case C.usearch_scalar_f16_k:
565+
c.Quantization = F16
566+
case C.usearch_scalar_f32_k:
567+
c.Quantization = F32
548568
case C.usearch_scalar_f64_k:
549-
c.Quantization = F64
569+
c.Quantization = F64
550570
case C.usearch_scalar_i8_k:
551-
c.Quantization = I8
571+
c.Quantization = I8
552572
case C.usearch_scalar_b1_k:
553-
c.Quantization = B1
573+
c.Quantization = B1
554574
}
555575

556576
return c, nil
557577
}
558578

559579
// Metadata loads the metadata from a specified file.
560-
func Metadata(path string) (c IndexConfig, err error) {
580+
func Metadata(path string) (c IndexConfig, err error) {
561581

562582
c_path := C.CString(path)
563583
defer C.free(unsafe.Pointer(c_path))
@@ -578,36 +598,36 @@ func Metadata(path string) (c IndexConfig, err error) {
578598

579599
// Map the metric kind
580600
switch options.metric_kind {
581-
case C.usearch_metric_l2sq_k:
582-
c.Metric = L2sq
601+
case C.usearch_metric_l2sq_k:
602+
c.Metric = L2sq
583603
case C.usearch_metric_ip_k:
584-
c.Metric = InnerProduct
585-
case C.usearch_metric_cos_k:
586-
c.Metric = Cosine
587-
case C.usearch_metric_haversine_k:
588-
c.Metric = Haversine
589-
case C.usearch_metric_pearson_k:
590-
c.Metric = Pearson
591-
case C.usearch_metric_hamming_k:
592-
c.Metric = Hamming
593-
case C.usearch_metric_tanimoto_k:
594-
c.Metric = Tanimoto
595-
case C.usearch_metric_sorensen_k:
596-
c.Metric = Sorensen
597-
}
604+
c.Metric = InnerProduct
605+
case C.usearch_metric_cos_k:
606+
c.Metric = Cosine
607+
case C.usearch_metric_haversine_k:
608+
c.Metric = Haversine
609+
case C.usearch_metric_pearson_k:
610+
c.Metric = Pearson
611+
case C.usearch_metric_hamming_k:
612+
c.Metric = Hamming
613+
case C.usearch_metric_tanimoto_k:
614+
c.Metric = Tanimoto
615+
case C.usearch_metric_sorensen_k:
616+
c.Metric = Sorensen
617+
}
598618

599619
// Map the quantization method
600620
switch options.quantization {
601-
case C.usearch_scalar_f16_k:
602-
c.Quantization = F16
603-
case C.usearch_scalar_f32_k:
604-
c.Quantization = F32
621+
case C.usearch_scalar_f16_k:
622+
c.Quantization = F16
623+
case C.usearch_scalar_f32_k:
624+
c.Quantization = F32
605625
case C.usearch_scalar_f64_k:
606-
c.Quantization = F64
626+
c.Quantization = F64
607627
case C.usearch_scalar_i8_k:
608-
c.Quantization = I8
628+
c.Quantization = I8
609629
case C.usearch_scalar_b1_k:
610-
c.Quantization = B1
630+
c.Quantization = B1
611631
}
612632

613633
return c, nil

golang/lib_test.go

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ func TestUSearch(t *testing.T) {
7979
t.Fatalf("Failed to reserve capacity: %s", err)
8080
}
8181

82+
err = ind.ChangeThreadsAdd(10)
83+
if err != nil {
84+
t.Fatalf("Failed to change threads add: %s", err)
85+
}
86+
8287
vec := make([]float32, dim)
8388
vec[0] = 40.0
8489
vec[1] = 2.0
@@ -111,6 +116,11 @@ func TestUSearch(t *testing.T) {
111116
t.Fatalf("Failed to reserve capacity: %s", err)
112117
}
113118

119+
err = ind.ChangeThreadsSearch(10)
120+
if err != nil {
121+
t.Fatalf("Failed to change threads search: %s", err)
122+
}
123+
114124
vec := make([]float32, dim)
115125
vec[0] = 40.0
116126
vec[1] = 2.0
@@ -125,12 +135,12 @@ func TestUSearch(t *testing.T) {
125135
t.Fatalf("Failed to search: %s", err)
126136
}
127137

128-
const tolerance = 1e-2 // For example, this sets the tolerance to 0.01
138+
const tolerance = 1e-2 // For example, this sets the tolerance to 0.01
129139
if keys[0] != 42 || math.Abs(float64(distances[0])) > tolerance {
130140
t.Fatalf("Expected result 42 with distance 0, got key %d with distance %f", keys[0], distances[0])
131141
}
132142

133-
// TODO: Add exact search
143+
// TODO: Add exact search
134144
})
135145

136146
t.Run("Test Save and Load", func(t *testing.T) {
@@ -158,22 +168,22 @@ func TestUSearch(t *testing.T) {
158168
}
159169

160170
vec := make([]float32, dim)
161-
for i := uint(0); i < dim; i++ {
162-
vec[i] = float32(i) + 0.2
163-
err = ind.Add(uint64(i), vec)
164-
if err != nil {
165-
t.Fatalf("Failed to insert: %s", err)
166-
}
167-
}
171+
for i := uint(0); i < dim; i++ {
172+
vec[i] = float32(i) + 0.2
173+
err = ind.Add(uint64(i), vec)
174+
if err != nil {
175+
t.Fatalf("Failed to insert: %s", err)
176+
}
177+
}
168178

169179
ind_length, err := ind.Len()
170180
if err != nil {
171181
t.Fatalf("Failed to retrieve size: %s", err)
172182
}
173183

174-
// TODO: Add invalid save and loads?
175-
buffer_size := uint(1*1024*1024)
176-
buf := make([]byte, buffer_size)
184+
// TODO: Add invalid save and loads?
185+
buffer_size := uint(1 * 1024 * 1024)
186+
buf := make([]byte, buffer_size)
177187
err = ind.SaveBuffer(buf, buffer_size)
178188
if err != nil {
179189
t.Fatalf("Failed to save the index to a buffer: %s", err)
@@ -191,7 +201,7 @@ func TestUSearch(t *testing.T) {
191201
if ind_length != ind2_length {
192202
t.Fatalf("Loaded index length %d doesn't match original of %d ", ind2_length, ind_length)
193203
}
194-
// TODO: Check some values
204+
// TODO: Check some values
195205

196206
err = indView.ViewBuffer(buf, buffer_size)
197207
if err != nil {
@@ -214,6 +224,6 @@ func TestUSearch(t *testing.T) {
214224
t.Fatalf("Loaded metadata doesn't match the index metadata")
215225
}
216226

217-
// TODO: Check file save/load/metadata
227+
// TODO: Check file save/load/metadata
218228
})
219229
}

0 commit comments

Comments
 (0)