Skip to content

Commit ba5b671

Browse files
authored
Add GetRuleByHandle, ResetRule & ResetRules methods (#326)
This commit introduces the following methods: - GetRuleByHandle - ResetRule - ResetRules It also refactors GetRules and the deprecated GetRule methods to share a common getRules implementation.
1 parent 1148f1a commit ba5b671

File tree

2 files changed

+298
-5
lines changed

2 files changed

+298
-5
lines changed

nftables_test.go

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7560,3 +7560,229 @@ func TestFlushWithGenID(t *testing.T) {
75607560
t.Errorf("expected table to not exist, got: %v", table)
75617561
}
75627562
}
7563+
7564+
func TestGetRuleByHandle(t *testing.T) {
7565+
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
7566+
defer nftest.CleanupSystemConn(t, newNS)
7567+
defer conn.FlushRuleset()
7568+
7569+
table := conn.AddTable(&nftables.Table{
7570+
Name: "test-table",
7571+
Family: nftables.TableFamilyIPv4,
7572+
})
7573+
7574+
chain := conn.AddChain(&nftables.Chain{
7575+
Name: "test-chain",
7576+
Table: table,
7577+
})
7578+
7579+
for i := range 3 {
7580+
conn.AddRule(&nftables.Rule{
7581+
Table: table,
7582+
Chain: chain,
7583+
UserData: fmt.Appendf([]byte{}, "rule-%d", i+1),
7584+
Exprs: []expr.Any{
7585+
&expr.Verdict{
7586+
Kind: expr.VerdictAccept,
7587+
},
7588+
},
7589+
})
7590+
}
7591+
7592+
if err := conn.Flush(); err != nil {
7593+
t.Fatalf("failed to flush: %v", err)
7594+
}
7595+
7596+
rules, err := conn.GetRules(table, chain)
7597+
if err != nil {
7598+
t.Fatalf("GetRules failed: %v", err)
7599+
}
7600+
7601+
want := rules[1]
7602+
7603+
got, err := conn.GetRuleByHandle(table, chain, want.Handle)
7604+
if err != nil {
7605+
t.Fatalf("GetRuleByHandle failed: %v", err)
7606+
}
7607+
if !bytes.Equal(got.UserData, want.UserData) {
7608+
t.Fatalf("expected userdata %q, got %q", got.UserData, want.UserData)
7609+
}
7610+
}
7611+
7612+
func TestResetRule(t *testing.T) {
7613+
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
7614+
defer nftest.CleanupSystemConn(t, newNS)
7615+
defer conn.FlushRuleset()
7616+
7617+
table := conn.AddTable(&nftables.Table{
7618+
Name: "test-table",
7619+
Family: nftables.TableFamilyIPv4,
7620+
})
7621+
7622+
chain := conn.AddChain(&nftables.Chain{
7623+
Name: "test-chain",
7624+
Table: table,
7625+
})
7626+
7627+
tests := [...]struct {
7628+
Bytes uint64
7629+
Packets uint64
7630+
Reset bool
7631+
}{
7632+
{
7633+
Bytes: 1024,
7634+
Packets: 1,
7635+
Reset: false,
7636+
},
7637+
{
7638+
Bytes: 2048,
7639+
Packets: 2,
7640+
Reset: true,
7641+
},
7642+
{
7643+
Bytes: 4096,
7644+
Packets: 4,
7645+
Reset: false,
7646+
},
7647+
}
7648+
7649+
for _, tt := range tests {
7650+
conn.AddRule(&nftables.Rule{
7651+
Table: table,
7652+
Chain: chain,
7653+
Exprs: []expr.Any{
7654+
&expr.Counter{
7655+
Bytes: tt.Bytes,
7656+
Packets: tt.Packets,
7657+
},
7658+
&expr.Verdict{
7659+
Kind: expr.VerdictAccept,
7660+
},
7661+
},
7662+
})
7663+
}
7664+
7665+
if err := conn.Flush(); err != nil {
7666+
t.Fatalf("flush failed: %v", err)
7667+
}
7668+
7669+
rules, err := conn.GetRules(table, chain)
7670+
if err != nil {
7671+
t.Fatalf("GetRules failed: %v", err)
7672+
}
7673+
7674+
if len(rules) != len(tests) {
7675+
t.Fatalf("expected %d rules, got %d", len(tests), len(rules))
7676+
}
7677+
7678+
for i, r := range rules {
7679+
if !tests[i].Reset {
7680+
continue
7681+
}
7682+
_, err := conn.ResetRule(table, chain, r.Handle)
7683+
if err != nil {
7684+
t.Fatalf("ResetRule failed: %v", err)
7685+
}
7686+
}
7687+
7688+
rules, err = conn.GetRules(table, chain)
7689+
if err != nil {
7690+
t.Fatalf("GetRules failed: %v", err)
7691+
}
7692+
7693+
for i, r := range rules {
7694+
counter, ok := r.Exprs[0].(*expr.Counter)
7695+
if !ok {
7696+
t.Errorf("expected first expr to be Counter, got %T", r.Exprs[0])
7697+
}
7698+
7699+
if tests[i].Reset {
7700+
if counter.Bytes != 0 || counter.Packets != 0 {
7701+
t.Errorf(
7702+
"expected counter values to be reset to zero, got Bytes=%d, Packets=%d",
7703+
counter.Bytes,
7704+
counter.Packets,
7705+
)
7706+
}
7707+
} else {
7708+
// Making sure that only the selected rules were reset
7709+
if counter.Bytes != tests[i].Bytes || counter.Packets != tests[i].Packets {
7710+
t.Errorf(
7711+
"unexpected counter values: got Bytes=%d, Packets=%d, want Bytes=%d, Packets=%d",
7712+
counter.Bytes,
7713+
counter.Packets,
7714+
tests[i].Bytes,
7715+
tests[i].Packets)
7716+
}
7717+
}
7718+
}
7719+
}
7720+
7721+
func TestResetRules(t *testing.T) {
7722+
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
7723+
defer nftest.CleanupSystemConn(t, newNS)
7724+
defer conn.FlushRuleset()
7725+
7726+
table := conn.AddTable(&nftables.Table{
7727+
Name: "test-table",
7728+
Family: nftables.TableFamilyIPv4,
7729+
})
7730+
7731+
chain := conn.AddChain(&nftables.Chain{
7732+
Name: "test-chain",
7733+
Table: table,
7734+
})
7735+
7736+
for range 3 {
7737+
conn.AddRule(&nftables.Rule{
7738+
Table: table,
7739+
Chain: chain,
7740+
Exprs: []expr.Any{
7741+
&expr.Counter{
7742+
Bytes: 1,
7743+
Packets: 1,
7744+
},
7745+
&expr.Verdict{
7746+
Kind: expr.VerdictAccept,
7747+
},
7748+
},
7749+
})
7750+
}
7751+
7752+
if err := conn.Flush(); err != nil {
7753+
t.Fatalf("flush failed: %v", err)
7754+
}
7755+
7756+
rules, err := conn.GetRules(table, chain)
7757+
if err != nil {
7758+
t.Fatalf("GetRules failed: %v", err)
7759+
}
7760+
7761+
if len(rules) != 3 {
7762+
t.Fatalf("expected %d rules, got %d", 3, len(rules))
7763+
}
7764+
7765+
if _, err := conn.ResetRules(table, chain); err != nil {
7766+
t.Fatalf("ResetRules failed: %v", err)
7767+
}
7768+
7769+
rules, err = conn.GetRules(table, chain)
7770+
if err != nil {
7771+
t.Fatalf("GetRules failed: %v", err)
7772+
}
7773+
7774+
for _, r := range rules {
7775+
counter, ok := r.Exprs[0].(*expr.Counter)
7776+
if !ok {
7777+
t.Errorf("expected first expr to be Counter, got %T", r.Exprs[0])
7778+
}
7779+
7780+
if counter.Bytes != 0 || counter.Packets != 0 {
7781+
t.Errorf(
7782+
"expected counter values to be reset to zero, got Bytes=%d, Packets=%d",
7783+
counter.Bytes,
7784+
counter.Packets,
7785+
)
7786+
}
7787+
}
7788+
}

rule.go

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,31 +71,98 @@ type Rule struct {
7171

7272
// GetRule returns the rules in the specified table and chain.
7373
//
74-
// Deprecated: use GetRules instead.
74+
// Deprecated: use GetRuleByHandle instead.
7575
func (cc *Conn) GetRule(t *Table, c *Chain) ([]*Rule, error) {
7676
return cc.GetRules(t, c)
7777
}
7878

79+
// GetRuleByHandle returns the rule in the specified table and chain by its
80+
// handle.
81+
// https://docs.kernel.org/networking/netlink_spec/nftables.html#getrule
82+
func (cc *Conn) GetRuleByHandle(t *Table, c *Chain, handle uint64) (*Rule, error) {
83+
rules, err := cc.getRules(t, c, unix.NFT_MSG_GETRULE, handle)
84+
if err != nil {
85+
return nil, err
86+
}
87+
88+
if got, want := len(rules), 1; got != want {
89+
return nil, fmt.Errorf("expected rule count %d, got %d", want, got)
90+
}
91+
92+
return rules[0], nil
93+
}
94+
7995
// GetRules returns the rules in the specified table and chain.
8096
func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) {
97+
return cc.getRules(t, c, unix.NFT_MSG_GETRULE, 0)
98+
}
99+
100+
// ResetRule resets the stateful expressions (e.g., counters) of the given
101+
// rule. The reset is applied immediately (no Flush is required). The returned
102+
// rule reflects its state prior to the reset. The provided rule must have a
103+
// valid Handle.
104+
// https://docs.kernel.org/networking/netlink_spec/nftables.html#getrule-reset
105+
func (cc *Conn) ResetRule(t *Table, c *Chain, handle uint64) (*Rule, error) {
106+
if handle == 0 {
107+
return nil, fmt.Errorf("rule must have a valid handle")
108+
}
109+
110+
rules, err := cc.getRules(t, c, unix.NFT_MSG_GETRULE_RESET, handle)
111+
if err != nil {
112+
return nil, err
113+
}
114+
115+
if got, want := len(rules), 1; got != want {
116+
return nil, fmt.Errorf("expected rule count %d, got %d", want, got)
117+
}
118+
119+
return rules[0], nil
120+
}
121+
122+
// ResetRules resets the stateful expressions (e.g., counters) of all rules
123+
// in the given table and chain. The reset is applied immediately (no Flush
124+
// is required). The returned rules reflect their state prior to the reset.
125+
// state.
126+
// https://docs.kernel.org/networking/netlink_spec/nftables.html#getrule-reset
127+
func (cc *Conn) ResetRules(t *Table, c *Chain) ([]*Rule, error) {
128+
return cc.getRules(t, c, unix.NFT_MSG_GETRULE_RESET, 0)
129+
}
130+
131+
// getRules retrieves rules from the given table and chain, using the provided
132+
// msgType (either unix.NFT_MSG_GETRULE or unix.NFT_MSG_GETRULE_RESET). If the
133+
// handle is non-zero, the operation applies only to the rule with that handle.
134+
func (cc *Conn) getRules(t *Table, c *Chain, msgType int, handle uint64) ([]*Rule, error) {
81135
conn, closer, err := cc.netlinkConn()
82136
if err != nil {
83137
return nil, err
84138
}
85139
defer func() { _ = closer() }()
86140

87-
data, err := netlink.MarshalAttributes([]netlink.Attribute{
141+
attrs := []netlink.Attribute{
88142
{Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")},
89143
{Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")},
90-
})
144+
}
145+
146+
var flags netlink.HeaderFlags = netlink.Request | netlink.Acknowledge | netlink.Dump
147+
148+
if handle != 0 {
149+
attrs = append(attrs, netlink.Attribute{
150+
Type: unix.NFTA_RULE_HANDLE,
151+
Data: binaryutil.BigEndian.PutUint64(handle),
152+
})
153+
154+
flags = netlink.Request | netlink.Acknowledge
155+
}
156+
157+
data, err := netlink.MarshalAttributes(attrs)
91158
if err != nil {
92159
return nil, err
93160
}
94161

95162
message := netlink.Message{
96163
Header: netlink.Header{
97-
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE),
98-
Flags: netlink.Request | netlink.Acknowledge | netlink.Dump,
164+
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | msgType),
165+
Flags: flags,
99166
},
100167
Data: append(extraHeader(uint8(t.Family), 0), data...),
101168
}

0 commit comments

Comments
 (0)