Skip to content

Commit 4f1e736

Browse files
committed
Refactoring MapGetter with generic
1 parent ec640a7 commit 4f1e736

File tree

6 files changed

+58
-60
lines changed

6 files changed

+58
-60
lines changed

query/options.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,13 +141,13 @@ func WhereBool[R record.Record](
141141
}
142142
}
143143

144-
func WhereMap[R record.Record](
145-
getter record.MapGetter[R],
144+
func WhereMap[R record.Record, K comparable, V any](
145+
getter record.MapGetter[R, K, V],
146146
condition where.ComparatorType,
147147
value ...any,
148148
) BuilderOption {
149149
return AddWhereOption[R]{
150-
Cmp: comparators.NewMapFieldComparator[R](condition, getter, value...),
150+
Cmp: comparators.NewMapFieldComparator[R, K, V](condition, getter, value...),
151151
}
152152
}
153153

record/getters.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ type (
1919
}
2020
BoolGetter[R Record] Getter[R, bool]
2121
ComparableGetter[R Record, T LessComparable] Getter[R, T]
22-
MapGetter[R Record] Getter[R, Map]
22+
MapGetter[R Record, K comparable, V any] Getter[R, Map[K, V]]
2323
SetGetter[R Record, T comparable] Getter[R, Set[T]]
2424
)
2525

26-
func (getter Getter[R, T]) GetForRecord(item R) T { return getter.Get(item) }
27-
func (getter BoolGetter[R]) GetForRecord(item R) bool { return getter.Get(item) }
28-
func (getter ComparableGetter[R, T]) GetForRecord(item R) T { return getter.Get(item) }
29-
func (getter MapGetter[R]) GetForRecord(item R) Map { return getter.Get(item) }
30-
func (getter SetGetter[R, T]) GetForRecord(item R) Set[T] { return getter.Get(item) }
26+
func (getter Getter[R, T]) GetForRecord(item R) T { return getter.Get(item) }
27+
func (getter BoolGetter[R]) GetForRecord(item R) bool { return getter.Get(item) }
28+
func (getter ComparableGetter[R, T]) GetForRecord(item R) T { return getter.Get(item) }
29+
func (getter MapGetter[R, K, V]) GetForRecord(item R) Map[K, V] { return getter.Get(item) }
30+
func (getter SetGetter[R, T]) GetForRecord(item R) Set[T] { return getter.Get(item) }
3131

3232
func (getter BoolGetter[R]) Less(a, b R) bool { return !getter.Get(a) && getter.Get(b) }
3333
func (getter ComparableGetter[R, T]) Less(a, b R) bool { return getter.Get(a) < getter.Get(b) }

record/map.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
package record
22

3-
type MapValueComparator interface {
4-
Compare(value interface{}) (bool, error)
3+
type MapValueComparator[V any] interface {
4+
Compare(value V) (bool, error)
55
}
66

7-
type Map interface {
8-
HasKey(key interface{}) bool
9-
HasValue(check MapValueComparator) (bool, error)
7+
type Map[K comparable, V any] interface {
8+
HasKey(key K) bool
9+
HasValue(check MapValueComparator[V]) (bool, error)
1010
}

simd_test.go

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,11 @@ const (
7474

7575
type Counters map[CounterKey]uint32
7676

77-
func (c Counters) HasKey(key any) bool {
78-
counterKey, ok := key.(CounterKey)
79-
if !ok {
80-
return false
81-
}
82-
_, ok = c[counterKey]
77+
func (c Counters) HasKey(key CounterKey) bool {
78+
_, ok := c[key]
8379
return ok
8480
}
85-
func (c Counters) HasValue(check record.MapValueComparator) (bool, error) {
81+
func (c Counters) HasValue(check record.MapValueComparator[uint32]) (bool, error) {
8682
for _, item := range c {
8783
res, err := check.Compare(item)
8884
if nil != err {
@@ -97,8 +93,8 @@ func (c Counters) HasValue(check record.MapValueComparator) (bool, error) {
9793

9894
type HasCounterValueEqual uint32
9995

100-
func (c HasCounterValueEqual) Compare(item any) (bool, error) {
101-
return item.(uint32) == uint32(c), nil
96+
func (c HasCounterValueEqual) Compare(item uint32) (bool, error) {
97+
return item == uint32(c), nil
10298
}
10399

104100
type User struct {
@@ -142,9 +138,9 @@ var userTags = record.SetGetter[*User, Tag]{
142138
Get: func(item *User) record.Set[Tag] { return item.Tags },
143139
}
144140

145-
var userCounters = record.MapGetter[*User]{
141+
var userCounters = record.MapGetter[*User, CounterKey, uint32]{
146142
Field: userFields.New("counters"),
147-
Get: func(item *User) record.Map { return item.Counters },
143+
Get: func(item *User) record.Map[CounterKey, uint32] { return item.Counters },
148144
}
149145

150146
type byOnline struct {

where/comparators/comparators_test.go

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,11 @@ type enum16 uint16
3232

3333
type mp map[int]int
3434

35-
func (m mp) HasKey(key any) bool {
36-
intKey, ok := key.(int)
37-
if !ok {
38-
return false
39-
}
40-
_, ok = m[intKey]
35+
func (m mp) HasKey(key int) bool {
36+
_, ok := m[key]
4137
return ok
4238
}
43-
func (m mp) HasValue(check record.MapValueComparator) (bool, error) {
39+
func (m mp) HasValue(check record.MapValueComparator[int]) (bool, error) {
4440
for _, value := range m {
4541
res, err := check.Compare(value)
4642
if nil != err {
@@ -53,9 +49,9 @@ func (m mp) HasValue(check record.MapValueComparator) (bool, error) {
5349
return false, nil
5450
}
5551

56-
type mapValueComparator func(item any) (bool, error)
52+
type mapValueComparator func(item int) (bool, error)
5753

58-
func (e mapValueComparator) Compare(item any) (bool, error) {
54+
func (e mapValueComparator) Compare(item int) (bool, error) {
5955
return e(item)
6056
}
6157

@@ -103,9 +99,9 @@ var ifaceGetter = record.Getter[*user, any]{
10399
Get: func(item *user) any { return item.iface },
104100
}
105101

106-
var mapGetter = record.MapGetter[*user]{
102+
var mapGetter = record.MapGetter[*user, int, int]{
107103
Field: fields.New("map"),
108-
Get: func(item *user) record.Map { return item.mp },
104+
Get: func(item *user) record.Map[int, int] { return item.mp },
109105
}
110106

111107
var setGetter = record.SetGetter[*user, int]{
@@ -767,31 +763,31 @@ func TestComparators(t *testing.T) { //nolint:maintidx
767763
comparator: NewMapFieldComparator[*user](
768764
where.MapHasValue,
769765
mapGetter,
770-
mapValueComparator(func(item any) (bool, error) {
771-
return item.(int) == 8, nil
766+
mapValueComparator(func(item int) (bool, error) {
767+
return item == 8, nil
772768
}),
773769
),
774770
expectedResult: true,
775771
expectedCmp: where.MapHasValue,
776772
expectedField: "map",
777-
expectedValues: []any{mapValueComparator(func(item any) (bool, error) {
778-
return item.(int) == 8, nil
773+
expectedValues: []any{mapValueComparator(func(item int) (bool, error) {
774+
return item == 8, nil
779775
})},
780776
},
781777
{
782778
name: "MapHasValue 10",
783779
comparator: NewMapFieldComparator[*user](
784780
where.MapHasValue,
785781
mapGetter,
786-
mapValueComparator(func(item any) (bool, error) {
787-
return item.(int) == 10, nil
782+
mapValueComparator(func(item int) (bool, error) {
783+
return item == 10, nil
788784
}),
789785
),
790786
expectedResult: false,
791787
expectedCmp: where.MapHasValue,
792788
expectedField: "map",
793-
expectedValues: []any{mapValueComparator(func(item any) (bool, error) {
794-
return item.(int) == 10, nil
789+
expectedValues: []any{mapValueComparator(func(item int) (bool, error) {
790+
return item == 10, nil
795791
})},
796792
},
797793
{
@@ -807,15 +803,15 @@ func TestComparators(t *testing.T) { //nolint:maintidx
807803
name: "MapHasValue error",
808804
comparator: NewMapFieldComparator[*user](
809805
where.MapHasValue, mapGetter,
810-
mapValueComparator(func(item any) (bool, error) {
806+
mapValueComparator(func(item int) (bool, error) {
811807
return false, errors.New("comparator error")
812808
}),
813809
),
814810
expectedResult: false,
815811
expectedError: errors.New("comparator error"),
816812
expectedCmp: where.MapHasValue,
817813
expectedField: "map",
818-
expectedValues: []any{mapValueComparator(func(item any) (bool, error) {
814+
expectedValues: []any{mapValueComparator(func(item int) (bool, error) {
819815
return false, errors.New("comparator error")
820816
})},
821817
},

where/comparators/map.go

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,63 @@
11
package comparators
22

33
import (
4+
"fmt"
5+
46
"github.com/shamcode/simd/record"
57
"github.com/shamcode/simd/where"
68
)
79

8-
type MapFieldComparator[R record.Record] struct {
10+
type MapFieldComparator[R record.Record, K comparable, V any] struct {
911
Cmp where.ComparatorType
10-
Getter record.MapGetter[R]
12+
Getter record.MapGetter[R, K, V]
1113
Value []any
1214
}
1315

14-
func (fc MapFieldComparator[R]) GetType() where.ComparatorType {
16+
func (fc MapFieldComparator[R, K, V]) GetType() where.ComparatorType {
1517
return fc.Cmp
1618
}
1719

18-
func (fc MapFieldComparator[R]) GetField() record.Field {
20+
func (fc MapFieldComparator[R, K, V]) GetField() record.Field {
1921
return fc.Getter.Field
2022
}
2123

22-
func (fc MapFieldComparator[R]) CompareValue(value record.Map) (bool, error) {
24+
func (fc MapFieldComparator[R, K, V]) CompareValue(value record.Map[K, V]) (bool, error) {
2325
switch fc.Cmp { //nolint:exhaustive
2426
case where.MapHasValue:
25-
cmp, ok := fc.Value[0].(record.MapValueComparator)
27+
cmp, ok := fc.Value[0].(record.MapValueComparator[V])
2628
if !ok {
2729
return false, NewFailCastTypeError(fc.GetField(), fc.Cmp, fc.Value[0], "record.MapValueComparator")
2830
}
2931
return value.HasValue(cmp)
3032
case where.MapHasKey:
31-
return value.HasKey(fc.Value[0]), nil
33+
val, ok := fc.Value[0].(K)
34+
if !ok {
35+
return false, NewFailCastTypeError(fc.GetField(), fc.Cmp, fc.Value[0], fmt.Sprintf("%T", val))
36+
}
37+
return value.HasKey(val), nil
3238
default:
3339
return false, NewNotImplementComparatorError(fc.GetField(), fc.Cmp)
3440
}
3541
}
3642

37-
func (fc MapFieldComparator[R]) Compare(item R) (bool, error) {
43+
func (fc MapFieldComparator[R, K, V]) Compare(item R) (bool, error) {
3844
return fc.CompareValue(fc.Getter.Get(item))
3945
}
4046

41-
func (fc MapFieldComparator[R]) ValuesCount() int {
47+
func (fc MapFieldComparator[R, K, V]) ValuesCount() int {
4248
return len(fc.Value)
4349
}
4450

45-
func (fc MapFieldComparator[R]) ValueAt(index int) any {
51+
func (fc MapFieldComparator[R, K, V]) ValueAt(index int) any {
4652
return fc.Value[index]
4753
}
4854

49-
func NewMapFieldComparator[R record.Record](
55+
func NewMapFieldComparator[R record.Record, K comparable, V any](
5056
cmp where.ComparatorType,
51-
getter record.MapGetter[R],
57+
getter record.MapGetter[R, K, V],
5258
value ...any,
53-
) MapFieldComparator[R] {
54-
return MapFieldComparator[R]{
59+
) MapFieldComparator[R, K, V] {
60+
return MapFieldComparator[R, K, V]{
5561
Cmp: cmp,
5662
Getter: getter,
5763
Value: value,

0 commit comments

Comments
 (0)