cmd/containerboot,util/linuxfw: create a SNAT rule for dst/src only once, clean up if needed (#13658)
The AddSNATRuleForDst rule was adding a new rule each time it was called including: - if a rule already existed - if a rule matching the destination, but with different desired source already existed This was causing issues especially for the in-progress egress HA proxies work, where the rules are now refreshed more frequently, so more redundant rules were being created. This change: - only creates the rule if it doesn't already exist - if a rule for the same dst, but different source is found, delete it - also ensures that egress proxies refresh firewall rules if the node's tailnet IP changes Updates tailscale/tailscale#13406 Signed-off-by: Irbe Krumina <irbe@tailscale.com>
This commit is contained in:
@ -27,32 +27,32 @@ func Test_nftablesRunner_EnsurePortMapRuleForSvc(t *testing.T) {
|
||||
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP)
|
||||
svcChains(t, 1, conn)
|
||||
chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv4)
|
||||
chainRule(t, "foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
|
||||
checkPortMapRule(t, "foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
|
||||
|
||||
// Create another rule for service 'foo' to forward TCP traffic to the
|
||||
// same IPv4 endpoint, but to a different port.
|
||||
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP1)
|
||||
svcChains(t, 1, conn)
|
||||
chainRuleCount(t, "foo", 2, conn, nftables.TableFamilyIPv4)
|
||||
chainRule(t, "foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4)
|
||||
checkPortMapRule(t, "foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4)
|
||||
|
||||
// Create a rule for service 'foo' to forward TCP traffic to an IPv6 endpoint
|
||||
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv6, pmTCP)
|
||||
svcChains(t, 2, conn)
|
||||
chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv6)
|
||||
chainRule(t, "foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
|
||||
checkPortMapRule(t, "foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
|
||||
|
||||
// Create a rule for service 'bar' to forward TCP traffic to IPv4 endpoint
|
||||
runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv4, pmTCP)
|
||||
svcChains(t, 3, conn)
|
||||
chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv4)
|
||||
chainRule(t, "bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
|
||||
checkPortMapRule(t, "bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
|
||||
|
||||
// Create a rule for service 'bar' to forward TCP traffic to an IPv6 endpoint
|
||||
runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv6, pmTCP)
|
||||
svcChains(t, 4, conn)
|
||||
chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv6)
|
||||
chainRule(t, "bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
|
||||
checkPortMapRule(t, "bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
|
||||
|
||||
// Delete service bar
|
||||
runner.DeleteSvc("bar", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP})
|
||||
@ -95,36 +95,26 @@ func svcChains(t *testing.T, wantCount int, conn *nftables.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
// chainRuleCount returns number of rules in a chain identified by service name and IP family.
|
||||
func chainRuleCount(t *testing.T, svc string, count int, conn *nftables.Conn, fam nftables.TableFamily) {
|
||||
// chainRuleCount verifies that the named chain in the given table contains the provided number of rules.
|
||||
func chainRuleCount(t *testing.T, name string, numOfRules int, conn *nftables.Conn, fam nftables.TableFamily) {
|
||||
t.Helper()
|
||||
chains, err := conn.ListChainsOfTableFamily(fam)
|
||||
if err != nil {
|
||||
t.Fatalf("error listing chains: %v", err)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, ch := range chains {
|
||||
if ch.Name == svc {
|
||||
found = true
|
||||
rules, err := conn.GetRules(ch.Table, ch)
|
||||
if err != nil {
|
||||
t.Fatalf("error getting rules: %v", err)
|
||||
}
|
||||
if len(rules) != count {
|
||||
t.Fatalf("unexpected number of rules, wants %d got %d", count, len(rules))
|
||||
}
|
||||
break
|
||||
if ch.Name == name {
|
||||
checkChainRules(t, conn, ch, numOfRules)
|
||||
return
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("chain for service %s does not exist", svc)
|
||||
}
|
||||
t.Fatalf("chain %s does not exist", name)
|
||||
}
|
||||
|
||||
// chainRule verifies that rule for the provided target IP and PortMap exists in
|
||||
// a chain identified by service name and IP family.
|
||||
func chainRule(t *testing.T, svc string, targetIP netip.Addr, pm PortMap, runner *nftablesRunner, fam nftables.TableFamily) {
|
||||
// checkPortMapRule verifies that rule for the provided target IP and PortMap exists in a chain identified by service
|
||||
// name and IP family.
|
||||
func checkPortMapRule(t *testing.T, svc string, targetIP netip.Addr, pm PortMap, runner *nftablesRunner, fam nftables.TableFamily) {
|
||||
t.Helper()
|
||||
chains, err := runner.conn.ListChainsOfTableFamily(fam)
|
||||
if err != nil {
|
||||
@ -146,11 +136,17 @@ func chainRule(t *testing.T, svc string, targetIP netip.Addr, pm PortMap, runner
|
||||
t.Fatalf("error converting protocol: %v", err)
|
||||
}
|
||||
wantsRule := portMapRule(chain.Table, chain, "tailscale0", targetIP, pm.MatchPort, pm.TargetPort, p, meta)
|
||||
gotRule, err := findRule(runner.conn, wantsRule)
|
||||
checkRule(t, wantsRule, runner.conn)
|
||||
}
|
||||
|
||||
// checkRule checks that the provided rules exists.
|
||||
func checkRule(t *testing.T, rule *nftables.Rule, conn *nftables.Conn) {
|
||||
t.Helper()
|
||||
gotRule, err := findRule(conn, rule)
|
||||
if err != nil {
|
||||
t.Fatalf("error looking up rule: %v", err)
|
||||
}
|
||||
if gotRule == nil {
|
||||
t.Fatalf("rule not found")
|
||||
t.Fatal("rule not found")
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user