cmd/containerboot,util/linuxfw: create a SNAT rule for dst/src only once, clean up if needed (#13658)

The AddSNATRuleForDst rule was adding a new rule each time it was called including:
- if a rule already existed
- if a rule matching the destination, but with different desired source already existed

This was causing issues especially for the in-progress egress HA proxies work,
where the rules are now refreshed more frequently, so more redundant rules
were being created.

This change:
- only creates the rule if it doesn't already exist
- if a rule for the same dst, but different source is found, delete it
- also ensures that egress proxies refresh firewall rules
if the node's tailnet IP changes

Updates tailscale/tailscale#13406

Signed-off-by: Irbe Krumina <irbe@tailscale.com>
This commit is contained in:
Irbe Krumina
2024-10-03 20:15:00 +01:00
committed by GitHub
parent a3c6a3a34f
commit 9bd158cc09
9 changed files with 272 additions and 73 deletions

View File

@ -27,32 +27,32 @@ func Test_nftablesRunner_EnsurePortMapRuleForSvc(t *testing.T) {
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP)
svcChains(t, 1, conn)
chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv4)
chainRule(t, "foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
checkPortMapRule(t, "foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
// Create another rule for service 'foo' to forward TCP traffic to the
// same IPv4 endpoint, but to a different port.
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP1)
svcChains(t, 1, conn)
chainRuleCount(t, "foo", 2, conn, nftables.TableFamilyIPv4)
chainRule(t, "foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4)
checkPortMapRule(t, "foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4)
// Create a rule for service 'foo' to forward TCP traffic to an IPv6 endpoint
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv6, pmTCP)
svcChains(t, 2, conn)
chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv6)
chainRule(t, "foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
checkPortMapRule(t, "foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
// Create a rule for service 'bar' to forward TCP traffic to IPv4 endpoint
runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv4, pmTCP)
svcChains(t, 3, conn)
chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv4)
chainRule(t, "bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
checkPortMapRule(t, "bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
// Create a rule for service 'bar' to forward TCP traffic to an IPv6 endpoint
runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv6, pmTCP)
svcChains(t, 4, conn)
chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv6)
chainRule(t, "bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
checkPortMapRule(t, "bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
// Delete service bar
runner.DeleteSvc("bar", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP})
@ -95,36 +95,26 @@ func svcChains(t *testing.T, wantCount int, conn *nftables.Conn) {
}
}
// chainRuleCount returns number of rules in a chain identified by service name and IP family.
func chainRuleCount(t *testing.T, svc string, count int, conn *nftables.Conn, fam nftables.TableFamily) {
// chainRuleCount verifies that the named chain in the given table contains the provided number of rules.
func chainRuleCount(t *testing.T, name string, numOfRules int, conn *nftables.Conn, fam nftables.TableFamily) {
t.Helper()
chains, err := conn.ListChainsOfTableFamily(fam)
if err != nil {
t.Fatalf("error listing chains: %v", err)
}
found := false
for _, ch := range chains {
if ch.Name == svc {
found = true
rules, err := conn.GetRules(ch.Table, ch)
if err != nil {
t.Fatalf("error getting rules: %v", err)
}
if len(rules) != count {
t.Fatalf("unexpected number of rules, wants %d got %d", count, len(rules))
}
break
if ch.Name == name {
checkChainRules(t, conn, ch, numOfRules)
return
}
}
if !found {
t.Fatalf("chain for service %s does not exist", svc)
}
t.Fatalf("chain %s does not exist", name)
}
// chainRule verifies that rule for the provided target IP and PortMap exists in
// a chain identified by service name and IP family.
func chainRule(t *testing.T, svc string, targetIP netip.Addr, pm PortMap, runner *nftablesRunner, fam nftables.TableFamily) {
// checkPortMapRule verifies that rule for the provided target IP and PortMap exists in a chain identified by service
// name and IP family.
func checkPortMapRule(t *testing.T, svc string, targetIP netip.Addr, pm PortMap, runner *nftablesRunner, fam nftables.TableFamily) {
t.Helper()
chains, err := runner.conn.ListChainsOfTableFamily(fam)
if err != nil {
@ -146,11 +136,17 @@ func chainRule(t *testing.T, svc string, targetIP netip.Addr, pm PortMap, runner
t.Fatalf("error converting protocol: %v", err)
}
wantsRule := portMapRule(chain.Table, chain, "tailscale0", targetIP, pm.MatchPort, pm.TargetPort, p, meta)
gotRule, err := findRule(runner.conn, wantsRule)
checkRule(t, wantsRule, runner.conn)
}
// checkRule checks that the provided rules exists.
func checkRule(t *testing.T, rule *nftables.Rule, conn *nftables.Conn) {
t.Helper()
gotRule, err := findRule(conn, rule)
if err != nil {
t.Fatalf("error looking up rule: %v", err)
}
if gotRule == nil {
t.Fatalf("rule not found")
t.Fatal("rule not found")
}
}