Skip to content

Commit bf153db

Browse files
committed
Add NetFromRange and NetFromIntervalRange helpers
When retrieving set elements it can be desired to format the result in the same way `nft` would, which is merging intervals to CIDR representations. To make this easier, introduce helper functions which allow for conversion of IP address ranges to CIDR networks. Signed-off-by: Georg Pfuetzenreuter <mail@georg-pfuetzenreuter.net>
1 parent 1db35da commit bf153db

File tree

2 files changed

+179
-0
lines changed

2 files changed

+179
-0
lines changed

util.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ package nftables
1616

1717
import (
1818
"encoding/binary"
19+
"errors"
1920
"net"
21+
"net/netip"
2022

2123
"github.com/google/nftables/binaryutil"
2224
"golang.org/x/sys/unix"
@@ -126,3 +128,48 @@ func NetInterval(cidr string) (net.IP, net.IP, error) {
126128

127129
return first, nextIP(last), nil
128130
}
131+
132+
// NetFromRange returns a CIDR IP network given a start and end address
133+
func NetFromRange(first net.IP, last net.IP) (*net.IPNet, error) {
134+
ip1 := net.IP(first)
135+
ip2 := net.IP(last)
136+
137+
maxLen := 32
138+
isIpv6 := ip1.To4() == nil
139+
140+
if isIpv6 && ip2.To4() != nil {
141+
return nil, errors.New("Cannot mix IPv4 and IPv6.")
142+
}
143+
144+
if isIpv6 {
145+
maxLen = 128
146+
}
147+
148+
for l := maxLen; l >= 0; l-- {
149+
cidrmask := net.CIDRMask(l, maxLen)
150+
ipmask := ip2.Mask(cidrmask)
151+
ipnet := net.IPNet{
152+
IP: ipmask,
153+
Mask: cidrmask,
154+
}
155+
156+
if ipnet.Contains(ip1) {
157+
return &ipnet, nil
158+
}
159+
}
160+
161+
return nil, errors.New("Failed to construct network from range.")
162+
}
163+
164+
// NetFromNetInterval returns a CIDR IP network given a start and end address as found in intervals.
165+
// This is similar to NetFromRange, but subtracts one address from the end of the range.
166+
func NetFromIntervalRange(first net.IP, last net.IP) (out *net.IPNet, err error) {
167+
ip2, ok := netip.AddrFromSlice(last)
168+
if !ok {
169+
return nil, errors.New("Failed to construct slice from network.")
170+
}
171+
172+
previous := ip2.Prev()
173+
174+
return NetFromRange(first, previous.AsSlice())
175+
}

util_test.go

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,135 @@ func TestNetInterval(t *testing.T) {
201201
})
202202
}
203203
}
204+
205+
func TestNetFromRange(t *testing.T) {
206+
tests := []struct {
207+
name string
208+
first string
209+
last string
210+
wantNet string
211+
wantErr bool
212+
}{
213+
{
214+
first: "0.0.0.1",
215+
last: "255.255.255.254",
216+
wantNet: "0.0.0.0/0",
217+
wantErr: false,
218+
},
219+
{
220+
first: "192.168.4.0",
221+
last: "192.168.4.255",
222+
wantNet: "192.168.4.0/24",
223+
wantErr: false,
224+
},
225+
{
226+
first: "192.0.2.17",
227+
last: "192.0.2.30",
228+
wantNet: "192.0.2.16/28",
229+
wantErr: false,
230+
},
231+
{
232+
first: "2001:db8:100::",
233+
last: "2001:db8:100:ffff:ffff:ffff:ffff:ffff",
234+
wantNet: "2001:db8:100::/48",
235+
wantErr: false,
236+
},
237+
{
238+
first: "2001:db8:100::",
239+
last: "192.0.2.30",
240+
wantNet: "",
241+
wantErr: true,
242+
},
243+
{
244+
first: "192.0.2.30",
245+
last: "2001:db8:100::",
246+
wantNet: "",
247+
wantErr: true,
248+
},
249+
}
250+
251+
for _, tt := range tests {
252+
t.Run(tt.first+"-"+tt.last, func(t *testing.T) {
253+
gotNet, err := NetFromRange(net.ParseIP(tt.first), net.ParseIP(tt.last))
254+
if (err != nil) != tt.wantErr {
255+
t.Errorf("NetFromRange() error = %v, wantErr = %v", err, tt.wantErr)
256+
}
257+
258+
if tt.wantNet == "" {
259+
return
260+
}
261+
262+
_, wantNetParsed, err := net.ParseCIDR(tt.wantNet)
263+
if err != nil {
264+
t.Fatalf("NetFromRange() error parsing test network = %v", err)
265+
}
266+
267+
if !reflect.DeepEqual(gotNet, wantNetParsed) {
268+
t.Errorf("NetFromRange() gotNet = %+v, wantNet = %+v", gotNet, wantNetParsed)
269+
}
270+
})
271+
}
272+
}
273+
274+
func TestNetFromIntervalRange(t *testing.T) {
275+
tests := []struct {
276+
name string
277+
first string
278+
last string
279+
wantNet string
280+
wantErr bool
281+
}{
282+
{
283+
first: "192.0.2.16",
284+
last: "192.0.2.32",
285+
wantNet: "192.0.2.16/28",
286+
wantErr: false,
287+
},
288+
{
289+
first: "2001:db8:100::",
290+
last: "2001:db8:101::",
291+
wantNet: "2001:db8:100::/48",
292+
wantErr: false,
293+
},
294+
{
295+
first: "2001:db8:a1:11::",
296+
last: "2001:db8:a1:12::",
297+
wantNet: "2001:db8:a1:11::/64",
298+
wantErr: false,
299+
},
300+
{
301+
first: "2001:db8:100::",
302+
last: "192.0.2.30",
303+
wantNet: "",
304+
wantErr: true,
305+
},
306+
{
307+
first: "192.0.2.30",
308+
last: "2001:db8:100::",
309+
wantNet: "",
310+
wantErr: true,
311+
},
312+
}
313+
314+
for _, tt := range tests {
315+
t.Run(tt.first+"-"+tt.last, func(t *testing.T) {
316+
gotNet, err := NetFromIntervalRange(net.ParseIP(tt.first), net.ParseIP(tt.last))
317+
if (err != nil) != tt.wantErr {
318+
t.Errorf("NetFromIntervalRange() error = %v, wantErr = %v", err, tt.wantErr)
319+
}
320+
321+
if tt.wantNet == "" {
322+
return
323+
}
324+
325+
_, wantNetParsed, err := net.ParseCIDR(tt.wantNet)
326+
if err != nil {
327+
t.Fatalf("NetFromIntervalRange() error parsing test network = %v", err)
328+
}
329+
330+
if !reflect.DeepEqual(gotNet, wantNetParsed) {
331+
t.Errorf("NetFromIntervalRange() gotNet = %+v, wantNet = %+v", gotNet, wantNetParsed)
332+
}
333+
})
334+
}
335+
}

0 commit comments

Comments
 (0)