Skip to content

Commit 207a463

Browse files
authored
Set rule handle during flush (#299)
This change makes it possible to delete rules after inserting them, without needing to query the rules first. Additionally, this allows positioning a new rule next to an existing rule. There are two ways to refer to a rule: Either by ID or by handle. The ID is assigned by userspace, and is only valid within a transaction, so it can only be used before the flush. The handle is assigned by the kernel when the transaction is committed, and can thus only be used after the flush. We thus need to set an ID on each newly created rule, and retrieve the handle of the rule during the flush. I extended the message struct with a pointer to the Rule which the message creates. This allows calling the reply handler callback which sets the handle. I updated tests to add a handle to generated replies for the NFT_MSG_NEWRULE messages.
1 parent 9a2862f commit 207a463

File tree

10 files changed

+101
-32
lines changed

10 files changed

+101
-32
lines changed

chain.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ func (cc *Conn) AddChain(c *Chain) *Chain {
140140
{Type: unix.NFTA_CHAIN_TYPE, Data: []byte(c.Type + "\x00")},
141141
})...)
142142
}
143-
cc.messages = append(cc.messages, netlink.Message{
143+
cc.messages = append(cc.messages, netlinkMessage{
144144
Header: netlink.Header{
145145
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN),
146146
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
@@ -161,7 +161,7 @@ func (cc *Conn) DelChain(c *Chain) {
161161
{Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")},
162162
})
163163

164-
cc.messages = append(cc.messages, netlink.Message{
164+
cc.messages = append(cc.messages, netlinkMessage{
165165
Header: netlink.Header{
166166
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELCHAIN),
167167
Flags: netlink.Request | netlink.Acknowledge,
@@ -179,7 +179,7 @@ func (cc *Conn) FlushChain(c *Chain) {
179179
{Type: unix.NFTA_RULE_TABLE, Data: []byte(c.Table.Name + "\x00")},
180180
{Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")},
181181
})
182-
cc.messages = append(cc.messages, netlink.Message{
182+
cc.messages = append(cc.messages, netlinkMessage{
183183
Header: netlink.Header{
184184
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE),
185185
Flags: netlink.Request | netlink.Acknowledge,

conn.go

+28-13
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,20 @@ type Conn struct {
4141

4242
lasting bool // establish a lasting connection to be used across multiple netlink operations.
4343
mu sync.Mutex // protects the following state
44-
messages []netlink.Message
44+
messages []netlinkMessage
4545
err error
4646
nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol.
4747
sockOptions []SockOption
4848
lastID uint32
4949
allocatedIDs uint32
5050
}
5151

52+
type netlinkMessage struct {
53+
Header netlink.Header
54+
Data []byte
55+
rule *Rule
56+
}
57+
5258
// ConnOption is an option to change the behavior of the nftables Conn returned by Open.
5359
type ConnOption func(*Conn)
5460

@@ -268,6 +274,11 @@ func (cc *Conn) Flush() error {
268274
} else if replyIndex < len(cc.messages) {
269275
msg := messages[replyIndex+1]
270276
if msg.Header.Sequence == reply.Header.Sequence && msg.Header.Type == reply.Header.Type {
277+
// The only messages which set the echo flag are rule create messages.
278+
err := cc.messages[replyIndex].rule.handleCreateReply(reply)
279+
if err != nil {
280+
errs = errors.Join(errs, err)
281+
}
271282
replyIndex++
272283
for replyIndex < len(cc.messages) && cc.messages[replyIndex].Header.Flags&netlink.Echo == 0 {
273284
replyIndex++
@@ -309,7 +320,7 @@ func (cc *Conn) Flush() error {
309320
func (cc *Conn) FlushRuleset() {
310321
cc.mu.Lock()
311322
defer cc.mu.Unlock()
312-
cc.messages = append(cc.messages, netlink.Message{
323+
cc.messages = append(cc.messages, netlinkMessage{
313324
Header: netlink.Header{
314325
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE),
315326
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
@@ -368,26 +379,30 @@ func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte {
368379
return b
369380
}
370381

371-
func batch(messages []netlink.Message) []netlink.Message {
372-
batch := []netlink.Message{
373-
{
374-
Header: netlink.Header{
375-
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN),
376-
Flags: netlink.Request,
377-
},
378-
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
382+
func batch(messages []netlinkMessage) []netlink.Message {
383+
batch := make([]netlink.Message, len(messages)+2)
384+
batch[0] = netlink.Message{
385+
Header: netlink.Header{
386+
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN),
387+
Flags: netlink.Request,
379388
},
389+
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
380390
}
381391

382-
batch = append(batch, messages...)
392+
for i, msg := range messages {
393+
batch[i+1] = netlink.Message{
394+
Header: msg.Header,
395+
Data: msg.Data,
396+
}
397+
}
383398

384-
batch = append(batch, netlink.Message{
399+
batch[len(messages)+1] = netlink.Message{
385400
Header: netlink.Header{
386401
Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END),
387402
Flags: netlink.Request,
388403
},
389404
Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES),
390-
})
405+
}
391406

392407
return batch
393408
}

flowtable.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ func (cc *Conn) AddFlowtable(f *Flowtable) *Flowtable {
142142
{Type: unix.NLA_F_NESTED | NFTA_FLOWTABLE_HOOK, Data: cc.marshalAttr(hookAttr)},
143143
})...)
144144

145-
cc.messages = append(cc.messages, netlink.Message{
145+
cc.messages = append(cc.messages, netlinkMessage{
146146
Header: netlink.Header{
147147
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_NEWFLOWTABLE),
148148
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
@@ -162,7 +162,7 @@ func (cc *Conn) DelFlowtable(f *Flowtable) {
162162
{Type: NFTA_FLOWTABLE_NAME, Data: []byte(f.Name)},
163163
})
164164

165-
cc.messages = append(cc.messages, netlink.Message{
165+
cc.messages = append(cc.messages, netlinkMessage{
166166
Header: netlink.Header{
167167
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_DELFLOWTABLE),
168168
Flags: netlink.Request | netlink.Acknowledge,

internal/nftest/nftest.go

+11
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ import (
88
"testing"
99

1010
"github.com/google/nftables"
11+
"github.com/google/nftables/binaryutil"
1112
"github.com/mdlayher/netlink"
13+
"golang.org/x/sys/unix"
1214
)
1315

1416
// Recorder provides an nftables connection that does not send to the Linux
@@ -21,6 +23,7 @@ type Recorder struct {
2123
// Conn opens an nftables connection that records netlink messages into the
2224
// Recorder.
2325
func (r *Recorder) Conn() (*nftables.Conn, error) {
26+
nextHandle := uint64(1)
2427
return nftables.New(nftables.WithTestDial(
2528
func(req []netlink.Message) ([]netlink.Message, error) {
2629
r.requests = append(r.requests, req...)
@@ -30,6 +33,14 @@ func (r *Recorder) Conn() (*nftables.Conn, error) {
3033
for _, msg := range req {
3134
if msg.Header.Flags&netlink.Echo != 0 {
3235
data := append([]byte{}, msg.Data...)
36+
switch msg.Header.Type {
37+
case netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE):
38+
attrs, _ := netlink.MarshalAttributes([]netlink.Attribute{
39+
{Type: unix.NFTA_RULE_HANDLE, Data: binaryutil.BigEndian.PutUint64(nextHandle)},
40+
})
41+
nextHandle++
42+
data = append(data, attrs...)
43+
}
3344
replies = append(replies, netlink.Message{
3445
Header: msg.Header,
3546
Data: data,

nftables_test.go

+10-1
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ func linediff(a, b string) string {
8080
}
8181

8282
func expectMessages(t *testing.T, want [][]byte) nftables.ConnOption {
83+
nextHandle := uint64(1)
8384
return nftables.WithTestDial(func(req []netlink.Message) ([]netlink.Message, error) {
8485
var replies []netlink.Message
8586
for idx, msg := range req {
@@ -103,6 +104,14 @@ func expectMessages(t *testing.T, want [][]byte) nftables.ConnOption {
103104
// Generate replies.
104105
if msg.Header.Flags&netlink.Echo != 0 {
105106
data := append([]byte{}, msg.Data...)
107+
switch msg.Header.Type {
108+
case netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE):
109+
attrs, _ := netlink.MarshalAttributes([]netlink.Attribute{
110+
{Type: unix.NFTA_RULE_HANDLE, Data: binaryutil.BigEndian.PutUint64(nextHandle)},
111+
})
112+
nextHandle++
113+
data = append(data, attrs...)
114+
}
106115
replies = append(replies, netlink.Message{
107116
Header: msg.Header,
108117
Data: data,
@@ -316,7 +325,7 @@ func TestRuleHandle(t *testing.T) {
316325
}
317326

318327
for _, tt := range tests {
319-
for _, afterFlush := range []bool{false} {
328+
for _, afterFlush := range []bool{false, true} {
320329
flushName := map[bool]string{
321330
false: "-before-flush",
322331
true: "-after-flush",

obj.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ func (cc *Conn) AddObj(o Obj) Obj {
124124
attrs = append(attrs, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA, Data: data})
125125
}
126126

127-
cc.messages = append(cc.messages, netlink.Message{
127+
cc.messages = append(cc.messages, netlinkMessage{
128128
Header: netlink.Header{
129129
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ),
130130
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
@@ -146,7 +146,7 @@ func (cc *Conn) DeleteObject(o Obj) {
146146
data := cc.marshalAttr(attrs)
147147
data = append(data, cc.marshalAttr([]netlink.Attribute{{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA}})...)
148148

149-
cc.messages = append(cc.messages, netlink.Message{
149+
cc.messages = append(cc.messages, netlinkMessage{
150150
Header: netlink.Header{
151151
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELOBJ),
152152
Flags: netlink.Request | netlink.Acknowledge,

rule.go

+34-3
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,13 @@ const (
4848
type Rule struct {
4949
Table *Table
5050
Chain *Chain
51-
// Handle identifies an existing Rule.
51+
// Handle identifies an existing Rule. For a new Rule, this field is set
52+
// during the Flush() in which the rule is committed. Make sure to not access
53+
// this field concurrently with this Flush() to avoid data races.
5254
Handle uint64
5355
// ID is an identifier for a new Rule, which is assigned by
5456
// AddRule/InsertRule, and only valid before the rule is committed by Flush().
57+
// The field is set to 0 during Flush().
5558
ID uint32
5659
// Position can be set to the Handle of another Rule to insert the new Rule
5760
// before (InsertRule) or after (AddRule) the existing rule.
@@ -171,11 +174,14 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule {
171174
}
172175

173176
var flags netlink.HeaderFlags
177+
var ruleRef *Rule
174178
switch op {
175179
case operationAdd:
176180
flags = netlink.Request | netlink.Acknowledge | netlink.Create | netlink.Echo | netlink.Append
181+
ruleRef = r
177182
case operationInsert:
178183
flags = netlink.Request | netlink.Acknowledge | netlink.Create | netlink.Echo
184+
ruleRef = r
179185
case operationReplace:
180186
flags = netlink.Request | netlink.Acknowledge | netlink.Replace
181187
}
@@ -190,17 +196,42 @@ func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule {
190196
})...)
191197
}
192198

193-
cc.messages = append(cc.messages, netlink.Message{
199+
cc.messages = append(cc.messages, netlinkMessage{
194200
Header: netlink.Header{
195201
Type: newRuleHeaderType,
196202
Flags: flags,
197203
},
198204
Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...),
205+
rule: ruleRef,
199206
})
200207

201208
return r
202209
}
203210

211+
func (r *Rule) handleCreateReply(reply netlink.Message) error {
212+
ad, err := netlink.NewAttributeDecoder(reply.Data[4:])
213+
if err != nil {
214+
return err
215+
}
216+
ad.ByteOrder = binary.BigEndian
217+
var handle uint64
218+
for ad.Next() {
219+
switch ad.Type() {
220+
case unix.NFTA_RULE_HANDLE:
221+
handle = ad.Uint64()
222+
}
223+
}
224+
if ad.Err() != nil {
225+
return ad.Err()
226+
}
227+
if handle == 0 {
228+
return fmt.Errorf("missing rule handle in create reply")
229+
}
230+
r.Handle = handle
231+
r.ID = 0
232+
return nil
233+
}
234+
204235
func (cc *Conn) ReplaceRule(r *Rule) *Rule {
205236
return cc.newRule(r, operationReplace)
206237
}
@@ -247,7 +278,7 @@ func (cc *Conn) DelRule(r *Rule) error {
247278
}
248279
flags := netlink.Request | netlink.Acknowledge
249280

250-
cc.messages = append(cc.messages, netlink.Message{
281+
cc.messages = append(cc.messages, netlinkMessage{
251282
Header: netlink.Header{
252283
Type: delRuleHeaderType,
253284
Flags: flags,

set.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ func (cc *Conn) appendElemList(s *Set, vals []SetElement, hdrType uint16) error
506506
{Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem},
507507
}
508508

509-
cc.messages = append(cc.messages, netlink.Message{
509+
cc.messages = append(cc.messages, netlinkMessage{
510510
Header: netlink.Header{
511511
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType),
512512
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
@@ -680,7 +680,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
680680
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | NFTA_SET_ELEM_EXPRESSIONS, Data: data})
681681
}
682682

683-
cc.messages = append(cc.messages, netlink.Message{
683+
cc.messages = append(cc.messages, netlinkMessage{
684684
Header: netlink.Header{
685685
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET),
686686
Flags: netlink.Request | netlink.Acknowledge | netlink.Create,
@@ -700,7 +700,7 @@ func (cc *Conn) DelSet(s *Set) {
700700
{Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")},
701701
{Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")},
702702
})
703-
cc.messages = append(cc.messages, netlink.Message{
703+
cc.messages = append(cc.messages, netlinkMessage{
704704
Header: netlink.Header{
705705
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSET),
706706
Flags: netlink.Request | netlink.Acknowledge,
@@ -717,7 +717,7 @@ func (cc *Conn) FlushSet(s *Set) {
717717
{Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")},
718718
{Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")},
719719
})
720-
cc.messages = append(cc.messages, netlink.Message{
720+
cc.messages = append(cc.messages, netlinkMessage{
721721
Header: netlink.Header{
722722
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM),
723723
Flags: netlink.Request | netlink.Acknowledge,

set_test.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,10 @@ func TestMarshalSet(t *testing.T) {
254254
}
255255
msg := c.messages[connMsgSetIdx]
256256

257-
nset, err := setsFromMsg(msg)
257+
nset, err := setsFromMsg(netlink.Message{
258+
Header: msg.Header,
259+
Data: msg.Data,
260+
})
258261
if err != nil {
259262
t.Fatalf("setsFromMsg() error: %+v", err)
260263
}

table.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func (cc *Conn) DelTable(t *Table) {
5757
{Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")},
5858
{Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}},
5959
})
60-
cc.messages = append(cc.messages, netlink.Message{
60+
cc.messages = append(cc.messages, netlinkMessage{
6161
Header: netlink.Header{
6262
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE),
6363
Flags: netlink.Request | netlink.Acknowledge,
@@ -73,7 +73,7 @@ func (cc *Conn) addTable(t *Table, flag netlink.HeaderFlags) *Table {
7373
{Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")},
7474
{Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}},
7575
})
76-
cc.messages = append(cc.messages, netlink.Message{
76+
cc.messages = append(cc.messages, netlinkMessage{
7777
Header: netlink.Header{
7878
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE),
7979
Flags: netlink.Request | netlink.Acknowledge | flag,
@@ -103,7 +103,7 @@ func (cc *Conn) FlushTable(t *Table) {
103103
data := cc.marshalAttr([]netlink.Attribute{
104104
{Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")},
105105
})
106-
cc.messages = append(cc.messages, netlink.Message{
106+
cc.messages = append(cc.messages, netlinkMessage{
107107
Header: netlink.Header{
108108
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE),
109109
Flags: netlink.Request | netlink.Acknowledge,

0 commit comments

Comments
 (0)