util/linuxfw: decoupling IPTables logic from linux router

This change introduces a new netfilterRunner interface and moves iptables manipulation into a lower-level iptables runner.

For #391

Signed-off-by: KevinLiang10 <kevinliang@tailscale.com>
This commit is contained in:
KevinLiang10
2023-06-16 18:54:58 +00:00
committed by KevinLiang10
parent 9c64e015e5
commit 243ce6ccc1
11 changed files with 1469 additions and 713 deletions

View File

@ -4,7 +4,6 @@
package router
import (
"bytes"
"errors"
"fmt"
"net"
@ -17,7 +16,6 @@ import (
"syscall"
"time"
"github.com/coreos/go-iptables/iptables"
"github.com/tailscale/netlink"
"github.com/tailscale/wireguard-go/tun"
"go4.org/netipx"
@ -25,9 +23,9 @@ import (
"golang.org/x/time/rate"
"tailscale.com/envknob"
"tailscale.com/net/netmon"
"tailscale.com/net/tsaddr"
"tailscale.com/types/logger"
"tailscale.com/types/preftype"
"tailscale.com/util/linuxfw"
"tailscale.com/util/multierr"
"tailscale.com/version/distro"
)
@ -38,56 +36,34 @@ const (
netfilterOn = preftype.NetfilterOn
)
// The following bits are added to packet marks for Tailscale use.
//
// We tried to pick bits sufficiently out of the way that it's
// unlikely to collide with existing uses. We have 4 bytes of mark
// bits to play with. We leave the lower byte alone on the assumption
// that sysadmins would use those. Kubernetes uses a few bits in the
// second byte, so we steer clear of that too.
//
// Empirically, most of the documentation on packet marks on the
// internet gives the impression that the marks are 16 bits
// wide. Based on this, we theorize that the upper two bytes are
// relatively unused in the wild, and so we consume bits 16:23 (the
// third byte).
//
// The constants are in the iptables/iproute2 string format for
// matching and setting the bits, so they can be directly embedded in
// commands.
const (
// The mask for reading/writing the firewall mark (fwmark) bits on a packet.
// See the comment on the const block on why we only use the third byte.
//
// We claim bits 16:23 entirely. For now we only use the lower four
// bits, leaving the higher 4 bits for future use.
tailscaleFwmarkMask = "0xff0000"
// tailscaleFwmarkMaskNum is the same mask as tailscaleFwmarkMask in
// numeric form, for code that sets the mask on a netlink rule rather
// than embedding it in a command line.
tailscaleFwmarkMaskNum = 0xff0000
// Packet is from Tailscale and to a subnet route destination, so
// is allowed to be routed through this machine.
tailscaleSubnetRouteMark = "0x40000"
// Packet was originated by tailscaled itself, and must not be
// routed over the Tailscale network.
//
// Keep this in sync with tailscaleBypassMark in
// net/netns/netns_linux.go.
tailscaleBypassMark = "0x80000"
// tailscaleBypassMarkNum is the numeric form of tailscaleBypassMark,
// used as the Mark of netlink rules.
tailscaleBypassMarkNum = 0x80000
)
// netfilterRunner abstracts helpers to run netfilter commands. It
// exists purely to swap out go-iptables for a fake implementation in
// tests.
type netfilterRunner interface {
// Insert adds the rule described by args to table/chain at position
// pos. Callers in this file pass pos 1 to put a rule at the top.
Insert(table, chain string, pos int, args ...string) error
// Append adds the rule described by args to the end of table/chain.
Append(table, chain string, args ...string) error
// Exists reports whether the rule described by args is present in
// table/chain.
Exists(table, chain string, args ...string) (bool, error)
// Delete removes the rule described by args from table/chain.
Delete(table, chain string, args ...string) error
// ClearChain flushes table/chain. Callers in this file treat an
// error with errCode 1 as "chain does not exist".
ClearChain(table, chain string) error
// NewChain creates table/chain.
NewChain(table, chain string) error
// DeleteChain removes table/chain.
DeleteChain(table, chain string) error

// AddLoopbackRule and DelLoopbackRule add and remove the loopback
// allow rule for addr.
// NOTE(review): the exact rule is implementation-defined; confirm
// against the runner implementation.
AddLoopbackRule(addr netip.Addr) error
DelLoopbackRule(addr netip.Addr) error

// AddHooks and DelHooks presumably install and remove the jumps
// from the main netfilter chains into the Tailscale chains; DelHooks
// is given a logger by its callers — TODO confirm how it is used.
AddHooks() error
DelHooks(logf logger.Logf) error

// AddChains and DelChains presumably create and remove the custom
// Tailscale chains — TODO confirm against the implementation.
AddChains() error
DelChains() error

// AddBase and DelBase presumably add and remove the basic
// processing rules for the tunnel interface named tunname — TODO
// confirm against the implementation.
AddBase(tunname string) error
DelBase() error

// AddSNATRule and DelSNATRule add and remove the SNAT rule used for
// subnet routes.
AddSNATRule() error
DelSNATRule() error

// HasIPV6 reports whether the runner considers IPv6 available.
HasIPV6() bool
// HasIPV6NAT reports whether the runner considers IPv6 NAT available.
HasIPV6NAT() bool
}
// newNetfilterRunner builds the netfilterRunner used by the router. It
// currently always returns the iptables-backed runner from linuxfw, or
// the error that creating it produced.
func newNetfilterRunner(logf logger.Logf) (netfilterRunner, error) {
	runner, err := linuxfw.NewIPTablesRunner(logf)
	if err != nil {
		// Return an untyped nil so callers never see a non-nil
		// interface wrapping a nil runner.
		return nil, err
	}
	return runner, nil
}
type linuxRouter struct {
@ -109,16 +85,13 @@ type linuxRouter struct {
// Various feature checks for the network stack.
ipRuleAvailable bool // whether kernel was built with IP_MULTIPLE_TABLES
v6Available bool
v6NATAvailable bool
fwmaskWorks bool // whether we can use 'ip rule...fwmark <mark>/<mask>'
// ipPolicyPrefBase is the base priority at which ip rules are installed.
ipPolicyPrefBase int
ipt4 netfilterRunner
ipt6 netfilterRunner
cmd commandRunner
nfr netfilterRunner
cmd commandRunner
}
func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Monitor) (Router, error) {
@ -127,51 +100,27 @@ func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Moni
return nil, err
}
ipt4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
nfr, err := newNetfilterRunner(logf)
if err != nil {
return nil, err
}
v6err := checkIPv6(logf)
if v6err != nil {
logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err)
}
supportsV6 := v6err == nil
supportsV6NAT := supportsV6 && supportsV6NAT()
if supportsV6 {
logf("v6nat = %v", supportsV6NAT)
}
var ipt6 netfilterRunner
if supportsV6 {
// The iptables package probes for `ip6tables` and errors out
// if unavailable. We want that to be a non-fatal error.
ipt6, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
return nil, err
}
}
cmd := osCommandRunner{
ambientCapNetAdmin: useAmbientCaps(),
}
return newUserspaceRouterAdvanced(logf, tunname, netMon, ipt4, ipt6, cmd, supportsV6, supportsV6NAT)
return newUserspaceRouterAdvanced(logf, tunname, netMon, nfr, cmd)
}
func newUserspaceRouterAdvanced(logf logger.Logf, tunname string, netMon *netmon.Monitor, netfilter4, netfilter6 netfilterRunner, cmd commandRunner, supportsV6, supportsV6NAT bool) (Router, error) {
func newUserspaceRouterAdvanced(logf logger.Logf, tunname string, netMon *netmon.Monitor, nfr netfilterRunner, cmd commandRunner) (Router, error) {
r := &linuxRouter{
logf: logf,
tunname: tunname,
netfilterMode: netfilterOff,
netMon: netMon,
v6Available: supportsV6,
v6NATAvailable: supportsV6NAT,
ipt4: netfilter4,
ipt6: netfilter6,
cmd: cmd,
nfr: nfr,
cmd: cmd,
ipRuleFixLimiter: rate.NewLimiter(rate.Every(5*time.Second), 10),
ipPolicyPrefBase: 5200,
@ -484,23 +433,23 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error {
case netfilterOff:
switch r.netfilterMode {
case netfilterNoDivert:
if err := r.delNetfilterBase(); err != nil {
if err := r.nfr.DelBase(); err != nil {
return err
}
if err := r.delNetfilterChains(); err != nil {
if err := r.nfr.DelChains(); err != nil {
r.logf("note: %v", err)
// harmless, continue.
// This can happen if someone left a ref to
// this table somewhere else.
}
case netfilterOn:
if err := r.delNetfilterHooks(); err != nil {
if err := r.nfr.DelHooks(r.logf); err != nil {
return err
}
if err := r.delNetfilterBase(); err != nil {
if err := r.nfr.DelBase(); err != nil {
return err
}
if err := r.delNetfilterChains(); err != nil {
if err := r.nfr.DelChains(); err != nil {
r.logf("note: %v", err)
// harmless, continue.
// This can happen if someone left a ref to
@ -512,15 +461,15 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error {
switch r.netfilterMode {
case netfilterOff:
reprocess = true
if err := r.addNetfilterChains(); err != nil {
if err := r.nfr.AddChains(); err != nil {
return err
}
if err := r.addNetfilterBase(); err != nil {
if err := r.nfr.AddBase(r.tunname); err != nil {
return err
}
r.snatSubnetRoutes = false
case netfilterOn:
if err := r.delNetfilterHooks(); err != nil {
if err := r.nfr.DelHooks(r.logf); err != nil {
return err
}
}
@ -529,33 +478,33 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error {
// we can't add a "-j ts-forward" rule to FORWARD
// while ts-forward contains an "-m mark" rule. But
// we can add the row *before* populating ts-forward.
// So we have to delNetFilterBase, then add the hooks,
// then re-addNetFilterBase, just in case.
// So we have to delBase, then add the hooks,
// then re-addBase, just in case.
switch r.netfilterMode {
case netfilterOff:
reprocess = true
if err := r.addNetfilterChains(); err != nil {
if err := r.nfr.AddChains(); err != nil {
return err
}
if err := r.delNetfilterBase(); err != nil {
if err := r.nfr.DelBase(); err != nil {
return err
}
if err := r.addNetfilterHooks(); err != nil {
if err := r.nfr.AddHooks(); err != nil {
return err
}
if err := r.addNetfilterBase(); err != nil {
if err := r.nfr.AddBase(r.tunname); err != nil {
return err
}
r.snatSubnetRoutes = false
case netfilterNoDivert:
reprocess = true
if err := r.delNetfilterBase(); err != nil {
if err := r.nfr.DelBase(); err != nil {
return err
}
if err := r.addNetfilterHooks(); err != nil {
if err := r.nfr.AddHooks(); err != nil {
return err
}
if err := r.addNetfilterBase(); err != nil {
if err := r.nfr.AddBase(r.tunname); err != nil {
return err
}
r.snatSubnetRoutes = false
@ -579,11 +528,19 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error {
return nil
}
// getV6Available reports whether IPv6 is available, as answered by the
// router's netfilter runner (HasIPV6).
func (r *linuxRouter) getV6Available() bool {
return r.nfr.HasIPV6()
}
// getV6NATAvailable reports whether IPv6 NAT is available, as answered
// by the router's netfilter runner (HasIPV6NAT).
func (r *linuxRouter) getV6NATAvailable() bool {
return r.nfr.HasIPV6NAT()
}
// addAddress adds an IP/mask to the tunnel interface. Fails if the
// address is already assigned to the interface, or if the addition
// fails.
func (r *linuxRouter) addAddress(addr netip.Prefix) error {
if !r.v6Available && addr.Addr().Is6() {
if !r.getV6Available() && addr.Addr().Is6() {
return nil
}
if r.useIPCommand() {
@ -609,7 +566,7 @@ func (r *linuxRouter) addAddress(addr netip.Prefix) error {
// the address is not assigned to the interface, or if the removal
// fails.
func (r *linuxRouter) delAddress(addr netip.Prefix) error {
if !r.v6Available && addr.Addr().Is6() {
if !r.getV6Available() && addr.Addr().Is6() {
return nil
}
if err := r.delLoopbackRule(addr.Addr()); err != nil {
@ -638,17 +595,8 @@ func (r *linuxRouter) addLoopbackRule(addr netip.Addr) error {
return nil
}
nf := r.ipt4
if addr.Is6() {
if !r.v6Available {
// IPv6 not available, ignore.
return nil
}
nf = r.ipt6
}
if err := nf.Insert("filter", "ts-input", 1, "-i", "lo", "-s", addr.String(), "-j", "ACCEPT"); err != nil {
return fmt.Errorf("adding loopback allow rule for %q: %w", addr, err)
if err := r.nfr.AddLoopbackRule(addr); err != nil {
return err
}
return nil
}
@ -660,17 +608,8 @@ func (r *linuxRouter) delLoopbackRule(addr netip.Addr) error {
return nil
}
nf := r.ipt4
if addr.Is6() {
if !r.v6Available {
// IPv6 not available, ignore.
return nil
}
nf = r.ipt6
}
if err := nf.Delete("filter", "ts-input", "-i", "lo", "-s", addr.String(), "-j", "ACCEPT"); err != nil {
return fmt.Errorf("deleting loopback allow rule for %q: %w", addr, err)
if err := r.nfr.DelLoopbackRule(addr); err != nil {
return err
}
return nil
}
@ -679,7 +618,7 @@ func (r *linuxRouter) delLoopbackRule(addr netip.Addr) error {
// interface. Fails if the route already exists, or if adding the
// route fails.
func (r *linuxRouter) addRoute(cidr netip.Prefix) error {
if !r.v6Available && cidr.Addr().Is6() {
if !r.getV6Available() && cidr.Addr().Is6() {
return nil
}
if r.useIPCommand() {
@ -704,7 +643,7 @@ func (r *linuxRouter) addThrowRoute(cidr netip.Prefix) error {
if !r.ipRuleAvailable {
return nil
}
if !r.v6Available && cidr.Addr().Is6() {
if !r.getV6Available() && cidr.Addr().Is6() {
return nil
}
if r.useIPCommand() {
@ -712,7 +651,7 @@ func (r *linuxRouter) addThrowRoute(cidr netip.Prefix) error {
}
err := netlink.RouteReplace(&netlink.Route{
Dst: netipx.PrefixIPNet(cidr.Masked()),
Table: tailscaleRouteTable.num,
Table: tailscaleRouteTable.Num,
Type: unix.RTN_THROW,
})
if err != nil {
@ -722,7 +661,7 @@ func (r *linuxRouter) addThrowRoute(cidr netip.Prefix) error {
}
func (r *linuxRouter) addRouteDef(routeDef []string, cidr netip.Prefix) error {
if !r.v6Available && cidr.Addr().Is6() {
if !r.getV6Available() && cidr.Addr().Is6() {
return nil
}
args := append([]string{"ip", "route", "add"}, routeDef...)
@ -756,7 +695,7 @@ var (
// interface. Fails if the route doesn't exist, or if removing the
// route fails.
func (r *linuxRouter) delRoute(cidr netip.Prefix) error {
if !r.v6Available && cidr.Addr().Is6() {
if !r.getV6Available() && cidr.Addr().Is6() {
return nil
}
if r.useIPCommand() {
@ -784,7 +723,7 @@ func (r *linuxRouter) delThrowRoute(cidr netip.Prefix) error {
if !r.ipRuleAvailable {
return nil
}
if !r.v6Available && cidr.Addr().Is6() {
if !r.getV6Available() && cidr.Addr().Is6() {
return nil
}
if r.useIPCommand() {
@ -803,7 +742,7 @@ func (r *linuxRouter) delThrowRoute(cidr netip.Prefix) error {
}
func (r *linuxRouter) delRouteDef(routeDef []string, cidr netip.Prefix) error {
if !r.v6Available && cidr.Addr().Is6() {
if !r.getV6Available() && cidr.Addr().Is6() {
return nil
}
args := append([]string{"ip", "route", "del"}, routeDef...)
@ -865,7 +804,7 @@ func (r *linuxRouter) linkIndex() (int, error) {
// routeTable returns the route table to use.
func (r *linuxRouter) routeTable() int {
if r.ipRuleAvailable {
return tailscaleRouteTable.num
return tailscaleRouteTable.Num
}
return 0
}
@ -962,7 +901,7 @@ func (f addrFamily) netlinkInt() int {
}
func (r *linuxRouter) addrFamilies() []addrFamily {
if r.v6Available {
if r.getV6Available() {
return []addrFamily{v4, v6}
}
return []addrFamily{v4}
@ -985,30 +924,34 @@ func (r *linuxRouter) addIPRules() error {
return r.justAddIPRules()
}
// routeTable is a Linux routing table: both its name and number.
// RouteTable is a Linux routing table: both its name and number.
// See /etc/iproute2/rt_tables.
type routeTable struct {
name string
num int
type RouteTable struct {
Name string
Num int
}
// ipCmdArg returns the string form of the table to pass to the "ip" command.
func (rt routeTable) ipCmdArg() string {
if rt.num >= 253 {
return rt.name
var routeTableByNumber = map[int]RouteTable{}
// ipCmdArg returns the string form of the table to pass to the "ip" command.
func (rt RouteTable) ipCmdArg() string {
if rt.Num >= 253 {
return rt.Name
}
return strconv.Itoa(rt.num)
return strconv.Itoa(rt.Num)
}
var routeTableByNumber = map[int]routeTable{}
func newRouteTable(name string, num int) routeTable {
rt := routeTable{name, num}
func newRouteTable(name string, num int) RouteTable {
rt := RouteTable{name, num}
routeTableByNumber[num] = rt
return rt
}
func mustRouteTable(num int) routeTable {
// mustRouteTable returns the RouteTable registered under the given number.
// It panics if the number is unknown: the result is used as an argument
// to an IP rule, and we don't want to continue with a table that does
// not exist.
func mustRouteTable(num int) RouteTable {
rt, ok := routeTableByNumber[num]
if !ok {
panic(fmt.Sprintf("unknown route table %v", num))
@ -1059,22 +1002,22 @@ var ipRules = []netlink.Rule{
// main routing table.
{
Priority: 10,
Mark: tailscaleBypassMarkNum,
Table: mainRouteTable.num,
Mark: linuxfw.TailscaleBypassMarkNum,
Table: mainRouteTable.Num,
},
// ...and then we try the 'default' table, for correctness,
// even though it's been empty on every Linux system I've ever seen.
{
Priority: 30,
Mark: tailscaleBypassMarkNum,
Table: defaultRouteTable.num,
Mark: linuxfw.TailscaleBypassMarkNum,
Table: defaultRouteTable.Num,
},
// If neither of those matched (no default route on this system?)
// then packets from us should be aborted rather than falling through
// to the tailscale routes, because that would create routing loops.
{
Priority: 50,
Mark: tailscaleBypassMarkNum,
Mark: linuxfw.TailscaleBypassMarkNum,
Type: unix.RTN_UNREACHABLE,
},
// If we get to this point, capture all packets and send them
@ -1084,7 +1027,7 @@ var ipRules = []netlink.Rule{
// beat non-VPN routes.
{
Priority: 70,
Table: tailscaleRouteTable.num,
Table: tailscaleRouteTable.Num,
},
// If that didn't match, then non-fwmark packets fall through to the
// usual rules (pref 32766 and 32767, ie. main and default).
@ -1105,7 +1048,7 @@ func (r *linuxRouter) justAddIPRules() error {
// Note: r is a value type here; safe to mutate it.
ru.Family = family.netlinkInt()
if ru.Mark != 0 {
ru.Mask = tailscaleFwmarkMaskNum
ru.Mask = linuxfw.TailscaleFwmarkMaskNum
}
ru.Goto = -1
ru.SuppressIfgroup = -1
@ -1138,7 +1081,7 @@ func (r *linuxRouter) addIPRulesWithIPCommand() error {
}
if rule.Mark != 0 {
if r.fwmaskWorks {
args = append(args, "fwmark", fmt.Sprintf("0x%x/%s", rule.Mark, tailscaleFwmarkMask))
args = append(args, "fwmark", fmt.Sprintf("0x%x/%s", rule.Mark, linuxfw.TailscaleFwmarkMask))
} else {
args = append(args, "fwmark", fmt.Sprintf("0x%x", rule.Mark))
}
@ -1239,284 +1182,6 @@ func (r *linuxRouter) delIPRulesWithIPCommand() error {
return rg.ErrAcc
}
// netfilterFamilies lists the netfilter runners that per-family rules
// are applied to: always the IPv4 runner, followed by the IPv6 runner
// when r.v6Available is set.
func (r *linuxRouter) netfilterFamilies() []netfilterRunner {
	families := []netfilterRunner{r.ipt4}
	if r.v6Available {
		families = append(families, r.ipt6)
	}
	return families
}
// addNetfilterChains makes sure the custom Tailscale chains exist and
// are empty: filter/ts-input and filter/ts-forward for every active
// address family, and nat/ts-postrouting for IPv4 and, when
// r.v6NATAvailable is set, for IPv6.
func (r *linuxRouter) addNetfilterChains() error {
	// ensure flushes table/chain; when the flush fails with errCode 1
	// (nonexistent chain) it creates the chain instead.
	ensure := func(nf netfilterRunner, table, chain string) error {
		err := nf.ClearChain(table, chain)
		switch {
		case errCode(err) == 1:
			return nf.NewChain(table, chain)
		case err != nil:
			return fmt.Errorf("setting up %s/%s: %w", table, chain, err)
		default:
			return nil
		}
	}

	for _, nf := range r.netfilterFamilies() {
		for _, chain := range []string{"ts-input", "ts-forward"} {
			if err := ensure(nf, "filter", chain); err != nil {
				return err
			}
		}
	}
	if err := ensure(r.ipt4, "nat", "ts-postrouting"); err != nil {
		return err
	}
	if !r.v6NATAvailable {
		return nil
	}
	return ensure(r.ipt6, "nat", "ts-postrouting")
}
// addNetfilterBase installs the basic processing rules that later
// helpers build on: the IPv4 rules always, and the IPv6 rules when
// r.v6Available is set.
func (r *linuxRouter) addNetfilterBase() error {
	if err := r.addNetfilterBase4(); err != nil {
		return err
	}
	if !r.v6Available {
		return nil
	}
	return r.addNetfilterBase6()
}
// addNetfilterBase4 appends the basic IPv4 processing rules to the
// filter/ts-input and filter/ts-forward chains. Later helpers add to
// these. It stops at, and returns, the first failure.
func (r *linuxRouter) addNetfilterBase4() error {
	// toInput and toForward append a single rule to the IPv4
	// filter/ts-input and filter/ts-forward chains respectively.
	toInput := func(args ...string) error {
		if err := r.ipt4.Append("filter", "ts-input", args...); err != nil {
			return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err)
		}
		return nil
	}
	toForward := func(args ...string) error {
		if err := r.ipt4.Append("filter", "ts-forward", args...); err != nil {
			return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err)
		}
		return nil
	}

	// CGNAT-range traffic may only arrive via the Tailscale interface.
	// The ranges used by ChromeOS are exempted by returning from the
	// Tailscale chain before the DROP rule is reached.
	//
	// Note, this will definitely break nodes that end up using the
	// CGNAT range for other purposes :(.
	if err := toInput("!", "-i", r.tunname, "-s", tsaddr.ChromeOSVMRange().String(), "-j", "RETURN"); err != nil {
		return err
	}
	if err := toInput("!", "-i", r.tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"); err != nil {
		return err
	}

	// Forward everything coming from the Tailscale interface and, by
	// default, drop traffic headed to it. Packet marks are used so that
	// both filter/FORWARD and nat/POSTROUTING can match the packets of
	// interest.
	//
	// SNAT in nat/POSTROUTING must only apply to packets that entered
	// through the Tailscale interface, but POSTROUTING cannot match on
	// the inbound interface. So the inbound interface is matched here
	// in filter/FORWARD and recorded as a packet mark that
	// nat/POSTROUTING can test later.
	subnetMark := tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask
	if err := toForward("-i", r.tunname, "-j", "MARK", "--set-mark", subnetMark); err != nil {
		return err
	}
	if err := toForward("-m", "mark", "--mark", subnetMark, "-j", "ACCEPT"); err != nil {
		return err
	}
	if err := toForward("-o", r.tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"); err != nil {
		return err
	}
	return toForward("-o", r.tunname, "-j", "ACCEPT")
}
// addNetfilterBase6 appends the basic IPv6 processing rules to the
// filter/ts-forward chain. Later helpers add to these. It stops at,
// and returns, the first failure.
func (r *linuxRouter) addNetfilterBase6() error {
	// toForward appends a single rule to the IPv6 filter/ts-forward chain.
	toForward := func(args ...string) error {
		if err := r.ipt6.Append("filter", "ts-forward", args...); err != nil {
			return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err)
		}
		return nil
	}

	// TODO: only allow traffic from Tailscale's ULA range to come
	// from tailscale0.

	subnetMark := tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask
	if err := toForward("-i", r.tunname, "-j", "MARK", "--set-mark", subnetMark); err != nil {
		return err
	}
	if err := toForward("-m", "mark", "--mark", subnetMark, "-j", "ACCEPT"); err != nil {
		return err
	}
	// TODO: drop forwarded traffic to tailscale0 from tailscale's ULA
	// (see corresponding IPv4 CGNAT rule).
	return toForward("-o", r.tunname, "-j", "ACCEPT")
}
// delNetfilterChains removes the custom Tailscale chains from netfilter:
// filter/ts-input and filter/ts-forward for every active address family,
// and nat/ts-postrouting for IPv4 and, when r.v6NATAvailable is set,
// for IPv6. A chain that does not exist is not an error. It returns the
// first failure, wrapped so the underlying error stays inspectable.
func (r *linuxRouter) delNetfilterChains() error {
	del := func(ipt netfilterRunner, table, chain string) error {
		if err := ipt.ClearChain(table, chain); err != nil {
			if errCode(err) == 1 {
				// nonexistent chain. That's fine, since it's
				// the desired state anyway.
				return nil
			}
			return fmt.Errorf("flushing %s/%s: %w", table, chain, err)
		}
		if err := ipt.DeleteChain(table, chain); err != nil {
			// this shouldn't fail, because if the chain didn't
			// exist, we would have returned after ClearChain.
			// Wrap with %w, like every other error in these
			// helpers, so callers can still unwrap the cause.
			return fmt.Errorf("deleting %s/%s: %w", table, chain, err)
		}
		return nil
	}

	for _, ipt := range r.netfilterFamilies() {
		if err := del(ipt, "filter", "ts-input"); err != nil {
			return err
		}
		if err := del(ipt, "filter", "ts-forward"); err != nil {
			return err
		}
	}
	if err := del(r.ipt4, "nat", "ts-postrouting"); err != nil {
		return err
	}
	if r.v6NATAvailable {
		if err := del(r.ipt6, "nat", "ts-postrouting"); err != nil {
			return err
		}
	}

	return nil
}
// delNetfilterBase flushes the custom Tailscale chains without removing
// them: filter/ts-input and filter/ts-forward for every active address
// family, and nat/ts-postrouting for IPv4 and, when r.v6NATAvailable is
// set, for IPv6.
func (r *linuxRouter) delNetfilterBase() error {
	// flush empties table/chain. errCode 1 means the chain does not
	// exist, which already is the desired state.
	flush := func(nf netfilterRunner, table, chain string) error {
		err := nf.ClearChain(table, chain)
		if err == nil || errCode(err) == 1 {
			return nil
		}
		return fmt.Errorf("flushing %s/%s: %w", table, chain, err)
	}

	for _, nf := range r.netfilterFamilies() {
		for _, chain := range []string{"ts-input", "ts-forward"} {
			if err := flush(nf, "filter", chain); err != nil {
				return err
			}
		}
	}
	if err := flush(r.ipt4, "nat", "ts-postrouting"); err != nil {
		return err
	}
	if !r.v6NATAvailable {
		return nil
	}
	return flush(r.ipt6, "nat", "ts-postrouting")
}
// addNetfilterHooks makes the main netfilter chains jump into the
// matching Tailscale chains: filter/INPUT and filter/FORWARD for every
// active address family, and nat/POSTROUTING for IPv4 and, when
// r.v6NATAvailable is set, for IPv6. The Tailscale chains must already
// exist.
func (r *linuxRouter) addNetfilterHooks() error {
	// hook inserts a "-j ts-<chain>" rule at the top of table/chain
	// unless an identical rule is already there.
	hook := func(nf netfilterRunner, table, chain string) error {
		jump := []string{"-j", tsChain(chain)}
		present, err := nf.Exists(table, chain, jump...)
		if err != nil {
			return fmt.Errorf("checking for %v in %s/%s: %w", jump, table, chain, err)
		}
		if present {
			return nil
		}
		if err := nf.Insert(table, chain, 1, jump...); err != nil {
			return fmt.Errorf("adding %v in %s/%s: %w", jump, table, chain, err)
		}
		return nil
	}

	for _, nf := range r.netfilterFamilies() {
		for _, chain := range []string{"INPUT", "FORWARD"} {
			if err := hook(nf, "filter", chain); err != nil {
				return err
			}
		}
	}
	if err := hook(r.ipt4, "nat", "POSTROUTING"); err != nil {
		return err
	}
	if !r.v6NATAvailable {
		return nil
	}
	return hook(r.ipt6, "nat", "POSTROUTING")
}
// delNetfilterHooks deletes the calls to tailscale's netfilter chains
// in the relevant main netfilter chains: filter/INPUT and
// filter/FORWARD for every active address family, and nat/POSTROUTING
// for IPv4 and, when r.v6NATAvailable is set, for IPv6.
//
// Deletion failures are logged and otherwise ignored, so this always
// returns nil; the error result is kept for the callers' sake.
func (r *linuxRouter) delNetfilterHooks() error {
	// unhook removes the "-j ts-<chain>" rule from table/chain.
	unhook := func(ipt netfilterRunner, table, chain string) {
		args := []string{"-j", tsChain(chain)}
		if err := ipt.Delete(table, chain, args...); err != nil {
			// TODO(apenwarr): check for errCode(1) here.
			// Unfortunately the error code from the iptables
			// module resists unwrapping, unlike with other
			// calls. So we have to assume if Delete fails,
			// it's because there is no such rule.
			//
			// The error is formatted with %v: %w is only
			// meaningful to fmt.Errorf and would be rendered
			// as "%!w(...)" by a plain printf-style logger.
			r.logf("note: deleting %v in %s/%s: %v", args, table, chain, err)
		}
	}

	for _, ipt := range r.netfilterFamilies() {
		unhook(ipt, "filter", "INPUT")
		unhook(ipt, "filter", "FORWARD")
	}
	unhook(r.ipt4, "nat", "POSTROUTING")
	if r.v6NATAvailable {
		unhook(r.ipt6, "nat", "POSTROUTING")
	}

	return nil
}
// addSNATRule adds a netfilter rule to SNAT traffic destined for
// local subnets.
func (r *linuxRouter) addSNATRule() error {
@ -1524,14 +1189,8 @@ func (r *linuxRouter) addSNATRule() error {
return nil
}
args := []string{"-m", "mark", "--mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask, "-j", "MASQUERADE"}
if err := r.ipt4.Append("nat", "ts-postrouting", args...); err != nil {
return fmt.Errorf("adding %v in v4/nat/ts-postrouting: %w", args, err)
}
if r.v6NATAvailable {
if err := r.ipt6.Append("nat", "ts-postrouting", args...); err != nil {
return fmt.Errorf("adding %v in v6/nat/ts-postrouting: %w", args, err)
}
if err := r.nfr.AddSNATRule(); err != nil {
return err
}
return nil
}
@ -1543,14 +1202,8 @@ func (r *linuxRouter) delSNATRule() error {
return nil
}
args := []string{"-m", "mark", "--mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask, "-j", "MASQUERADE"}
if err := r.ipt4.Delete("nat", "ts-postrouting", args...); err != nil {
return fmt.Errorf("deleting %v in v4/nat/ts-postrouting: %w", args, err)
}
if r.v6NATAvailable {
if err := r.ipt6.Delete("nat", "ts-postrouting", args...); err != nil {
return fmt.Errorf("deleting %v in v6/nat/ts-postrouting: %w", args, err)
}
if err := r.nfr.DelSNATRule(); err != nil {
return err
}
return nil
}
@ -1619,12 +1272,6 @@ func cidrDiff(kind string, old map[netip.Prefix]bool, new []netip.Prefix, add, d
return ret, nil
}
// tsChain maps a "parent" netfilter chain name (e.g. INPUT, FORWARD,
// ...) to the name of the tailscale sub-chain that hangs off it: the
// lower-cased parent name prefixed with "ts-".
func tsChain(chain string) string {
	const prefix = "ts-"
	lowered := strings.ToLower(chain)
	return prefix + lowered
}
// normalizeCIDR returns cidr as an ip/mask string, with the host bits
// of the IP address zeroed out.
func normalizeCIDR(cidr netip.Prefix) string {
@ -1632,105 +1279,9 @@ func normalizeCIDR(cidr netip.Prefix) string {
}
func cleanup(logf logger.Logf, interfaceName string) {
// TODO(dmytro): clean up iptables.
}
// checkIPv6 checks whether the system appears to have a working IPv6
// network stack. It returns an error explaining what looks wrong or
// missing. It does not check that IPv6 is currently functional or
// that there's a global address, just that the system would support
// IPv6 if it were on an IPv6 network.
func checkIPv6(logf logger.Logf) error {
_, err := os.Stat("/proc/sys/net/ipv6")
if os.IsNotExist(err) {
return err
if interfaceName != "userspace-networking" {
linuxfw.IPTablesCleanup(logf)
}
bs, err := os.ReadFile("/proc/sys/net/ipv6/conf/all/disable_ipv6")
if err != nil {
// Be conservative if we can't find the IPv6 configuration knob.
return err
}
disabled, err := strconv.ParseBool(strings.TrimSpace(string(bs)))
if err != nil {
return errors.New("disable_ipv6 has invalid bool")
}
if disabled {
return errors.New("disable_ipv6 is set")
}
// Older kernels don't support IPv6 policy routing. Some kernels
// support policy routing but don't have this knob, so absence of
// the knob is not fatal.
bs, err = os.ReadFile("/proc/sys/net/ipv6/conf/all/disable_policy")
if err == nil {
disabled, err = strconv.ParseBool(strings.TrimSpace(string(bs)))
if err != nil {
return errors.New("disable_policy has invalid bool")
}
if disabled {
return errors.New("disable_policy is set")
}
}
if err := checkIPRuleSupportsV6(logf); err != nil {
return fmt.Errorf("kernel doesn't support IPv6 policy routing: %w", err)
}
// Some distros ship ip6tables separately from iptables.
if _, err := exec.LookPath("ip6tables"); err != nil {
return err
}
return nil
}
// supportsV6NAT reports whether the IPv6 netfilter stack on this system
// has a "nat" table.
//
// The nat table was added after the initial release of ipv6 netfilter,
// so some older distros ship a kernel that can't NAT IPv6 traffic.
func supportsV6NAT() bool {
	tables, err := os.ReadFile("/proc/net/ip6_tables_names")
	switch {
	case err != nil:
		// Can't read the file. Assume SNAT works.
		return true
	case bytes.Contains(tables, []byte("nat\n")):
		return true
	}

	// In nftables mode, that proc file will be empty. Fall back to
	// checking whether the ip6table_nat module can be loaded.
	return exec.Command("modprobe", "ip6table_nat").Run() == nil
}
// checkIPRuleSupportsV6 probes whether IPv6 policy routing rules can be
// used. It returns nil if at least one IPv6 rule is already listed, or
// if a throwaway test rule can be added; otherwise it returns the error
// from listing or adding.
func checkIPRuleSupportsV6(logf logger.Logf) error {
	// Start with a read-only query so that, ideally, no state has to
	// be modified at all.
	existing, err := netlink.RuleList(netlink.FAMILY_V6)
	if err != nil {
		return fmt.Errorf("querying IPv6 policy routing rules: %w", err)
	}
	if n := len(existing); n > 0 {
		logf("[v1] kernel supports IPv6 policy routing (found %d rules)", n)
		return nil
	}

	// Nothing listed: actually create and delete a rule as a test.
	probe := netlink.NewRule()
	probe.Family = netlink.FAMILY_V6
	probe.Priority = 1234
	probe.Mark = tailscaleBypassMarkNum
	probe.Table = tailscaleRouteTable.num

	// Remove any leftover copy of the probe rule first; errors are
	// deliberately ignored because this is only cleanup.
	netlink.RuleDel(probe)
	// Remove the probe rule again on the way out.
	defer netlink.RuleDel(probe)
	return netlink.RuleAdd(probe)
}
// Checks if the running openWRT system is using mwan3, based on the heuristic

View File

@ -22,8 +22,10 @@ import (
"github.com/vishvananda/netlink"
"golang.org/x/exp/slices"
"tailscale.com/net/netmon"
"tailscale.com/net/tsaddr"
"tailscale.com/tstest"
"tailscale.com/types/logger"
"tailscale.com/util/linuxfw"
)
func TestRouterStates(t *testing.T) {
@ -328,7 +330,7 @@ ip route add throw 192.168.0.0/24 table 52` + basic,
defer mon.Close()
fake := NewFakeOS(t)
router, err := newUserspaceRouterAdvanced(t.Logf, "tailscale0", mon, fake.netfilter4, fake.netfilter6, fake, true, true)
router, err := newUserspaceRouterAdvanced(t.Logf, "tailscale0", mon, fake.nfr, fake)
if err != nil {
t.Fatalf("failed to create router: %v", err)
}
@ -362,15 +364,25 @@ ip route add throw 192.168.0.0/24 table 52` + basic,
}
}
type fakeNetfilter struct {
t *testing.T
n map[string][]string
type fakeIPTablesRunner struct {
t *testing.T
ipt4 map[string][]string
ipt6 map[string][]string
// We always assume IPv6 and IPv6 NAT are enabled when testing.
}
func newNetfilter(t *testing.T) *fakeNetfilter {
return &fakeNetfilter{
func newIPTablesRunner(t *testing.T) netfilterRunner {
return &fakeIPTablesRunner{
t: t,
n: map[string][]string{
ipt4: map[string][]string{
"filter/INPUT": nil,
"filter/OUTPUT": nil,
"filter/FORWARD": nil,
"nat/PREROUTING": nil,
"nat/OUTPUT": nil,
"nat/POSTROUTING": nil,
},
ipt6: map[string][]string{
"filter/INPUT": nil,
"filter/OUTPUT": nil,
"filter/FORWARD": nil,
@ -381,115 +393,222 @@ func newNetfilter(t *testing.T) *fakeNetfilter {
}
}
func (n *fakeNetfilter) Insert(table, chain string, pos int, args ...string) error {
k := table + "/" + chain
if rules, ok := n.n[k]; ok {
if pos > len(rules)+1 {
n.t.Errorf("bad position %d in %s", pos, k)
return errExec
// insertRule puts newRule at the top of the rule list stored under
// chain in curIPT. If curIPT has no entry for chain, the test is failed
// through n.t.Fatalf.
func insertRule(n *fakeIPTablesRunner, curIPT map[string][]string, chain, newRule string) error {
	existing, found := curIPT[chain]
	if !found {
		n.t.Fatalf("no %s chain exists", chain)
		return fmt.Errorf("no %s chain exists", chain)
	}

	// Build the new list with newRule first, followed by the rules
	// that were already there.
	updated := make([]string, 0, len(existing)+1)
	updated = append(updated, newRule)
	updated = append(updated, existing...)
	curIPT[chain] = updated
	return nil
}
// appendRule adds newRule after all rules currently stored under chain in
// curIPT.
//
// chain is a "table/chain" key into curIPT. If the key is not present, the
// test held by n is failed via Fatalf and an error describing the missing
// chain is returned.
func appendRule(n *fakeIPTablesRunner, curIPT map[string][]string, chain, newRule string) error {
	rules, found := curIPT[chain]
	if !found {
		n.t.Fatalf("no %s chain exists", chain)
		return fmt.Errorf("no %s chain exists", chain)
	}
	curIPT[chain] = append(rules, newRule)
	return nil
}
// deleteRule removes the first rule equal to delRule from the list stored
// under chain in curIPT. If no rule matches, the list is left as it was and
// nil is still returned.
//
// chain is a "table/chain" key into curIPT. If the key is not present, the
// test held by n is failed via Fatalf and an error describing the missing
// chain is returned.
func deleteRule(n *fakeIPTablesRunner, curIPT map[string][]string, chain, delRule string) error {
	rules, found := curIPT[chain]
	if !found {
		n.t.Fatalf("no %s chain exists", chain)
		return fmt.Errorf("no %s chain exists", chain)
	}

	for idx := range rules {
		if rules[idx] != delRule {
			continue
		}
		// Only the first match is dropped; later duplicates stay in place.
		rules = append(rules[:idx], rules[idx+1:]...)
		break
	}
	curIPT[chain] = rules
	return nil
}
// AddLoopbackRule records an ACCEPT rule for traffic arriving on lo from
// addr. The rule is inserted at the top of filter/ts-input in the IPv6 rule
// set when addr is an IPv6 address, and in the IPv4 rule set otherwise.
func (n *fakeIPTablesRunner) AddLoopbackRule(addr netip.Addr) error {
	rule := fmt.Sprintf("-i lo -s %s -j ACCEPT", addr.String())
	if addr.Is6() {
		return insertRule(n, n.ipt6, "filter/ts-input", rule)
	}
	return insertRule(n, n.ipt4, "filter/ts-input", rule)
}
func (n *fakeIPTablesRunner) AddBase(tunname string) error {
if err := n.AddBase4(tunname); err != nil {
return err
}
if n.HasIPV6() {
if err := n.AddBase6(tunname); err != nil {
return err
}
rules = append(rules, "")
copy(rules[pos:], rules[pos-1:])
rules[pos-1] = strings.Join(args, " ")
n.n[k] = rules
} else {
n.t.Errorf("unknown table/chain %s", k)
return errExec
}
return nil
}
// Append records a rule at the end of table/chain by delegating to Insert
// with the position one past the chain's current rule count.
func (n *fakeNetfilter) Append(table, chain string, args ...string) error {
	end := len(n.n[table+"/"+chain]) + 1
	return n.Insert(table, chain, end, args...)
}
func (n *fakeNetfilter) Exists(table, chain string, args ...string) (bool, error) {
k := table + "/" + chain
if rules, ok := n.n[k]; ok {
for _, rule := range rules {
if rule == strings.Join(args, " ") {
return true, nil
}
func (n *fakeIPTablesRunner) AddBase4(tunname string) error {
curIPT := n.ipt4
newRules := []struct{ chain, rule string }{
{"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j RETURN", tunname, tsaddr.ChromeOSVMRange().String())},
{"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j DROP", tunname, tsaddr.CGNATRange().String())},
{"filter/ts-forward", fmt.Sprintf("-i %s -j MARK --set-mark %s/%s", tunname, linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)},
{"filter/ts-forward", fmt.Sprintf("-m mark --mark %s/%s -j ACCEPT", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)},
{"filter/ts-forward", fmt.Sprintf("-o %s -s %s -j DROP", tunname, tsaddr.CGNATRange().String())},
{"filter/ts-forward", fmt.Sprintf("-o %s -j ACCEPT", tunname)},
}
for _, rule := range newRules {
if err := appendRule(n, curIPT, rule.chain, rule.rule); err != nil {
return fmt.Errorf("add rule %q to chain %q: %w", rule.rule, rule.chain, err)
}
return false, nil
} else {
n.t.Errorf("unknown table/chain %s", k)
return false, errExec
}
}
// Delete removes the first recorded rule in table/chain whose text equals
// args joined by single spaces.
//
// It reports a test error and returns errExec when the table/chain is not
// known, or when no recorded rule matches.
func (n *fakeNetfilter) Delete(table, chain string, args ...string) error {
	key := table + "/" + chain
	rules, known := n.n[key]
	if !known {
		n.t.Errorf("unknown table/chain %s", key)
		return errExec
	}

	want := strings.Join(args, " ")
	for idx, have := range rules {
		if have != want {
			continue
		}
		n.n[key] = append(rules[:idx], rules[idx+1:]...)
		return nil
	}

	n.t.Errorf("delete of unknown rule %q from %s", want, key)
	return errExec
}
// ClearChain drops every recorded rule in table/chain while keeping the
// chain itself registered.
//
// For an unknown table/chain it logs a note (without failing the test) and
// returns an error whose text is "exitcode:1".
func (n *fakeNetfilter) ClearChain(table, chain string) error {
	key := table + "/" + chain
	if _, known := n.n[key]; !known {
		n.t.Logf("note: ClearChain: unknown table/chain %s", key)
		return errors.New("exitcode:1")
	}
	n.n[key] = nil
	return nil
}
// NewChain registers table/chain with an empty rule list.
//
// If the table/chain is already registered it reports a test error and
// returns errExec, leaving the existing rules untouched.
func (n *fakeNetfilter) NewChain(table, chain string) error {
	key := table + "/" + chain
	if _, exists := n.n[key]; !exists {
		n.n[key] = nil
		return nil
	}
	n.t.Errorf("table/chain %s already exists", key)
	return errExec
}
func (n *fakeNetfilter) DeleteChain(table, chain string) error {
k := table + "/" + chain
if rules, ok := n.n[k]; ok {
if len(rules) != 0 {
n.t.Errorf("%s is not empty", k)
return errExec
}
delete(n.n, k)
return nil
} else {
n.t.Errorf("%s does not exist", k)
return errExec
// AddBase6 appends the base IPv6 forwarding rules for tunname to
// filter/ts-forward: mark packets entering via tunname, accept packets
// carrying that mark, and accept packets leaving via tunname.
//
// It stops at the first rule that cannot be appended and returns that error
// wrapped with the rule and chain involved.
func (n *fakeIPTablesRunner) AddBase6(tunname string) error {
	const chain = "filter/ts-forward"
	rules := []string{
		fmt.Sprintf("-i %s -j MARK --set-mark %s/%s", tunname, linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask),
		fmt.Sprintf("-m mark --mark %s/%s -j ACCEPT", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask),
		fmt.Sprintf("-o %s -j ACCEPT", tunname),
	}
	for _, rule := range rules {
		if err := appendRule(n, n.ipt6, chain, rule); err != nil {
			return fmt.Errorf("add rule %q to chain %q: %w", rule, chain, err)
		}
	}
	return nil
}
// DelLoopbackRule removes the lo ACCEPT rule for addr from filter/ts-input.
// The IPv6 rule set is used when addr is an IPv6 address, and the IPv4 rule
// set otherwise.
func (n *fakeIPTablesRunner) DelLoopbackRule(addr netip.Addr) error {
	rule := fmt.Sprintf("-i lo -s %s -j ACCEPT", addr.String())
	if addr.Is6() {
		return deleteRule(n, n.ipt6, "filter/ts-input", rule)
	}
	return deleteRule(n, n.ipt4, "filter/ts-input", rule)
}
// AddHooks inserts the jump rules into the ts-* chains at the top of
// filter/INPUT, filter/FORWARD and nat/POSTROUTING, first for the IPv4 rule
// set and then for the IPv6 one. It returns the first insertion error.
func (n *fakeIPTablesRunner) AddHooks() error {
	// Each entry is {chain, rule}.
	hooks := [...][2]string{
		{"filter/INPUT", "-j ts-input"},
		{"filter/FORWARD", "-j ts-forward"},
		{"nat/POSTROUTING", "-j ts-postrouting"},
	}
	for _, family := range [...]map[string][]string{n.ipt4, n.ipt6} {
		for _, hook := range hooks {
			if err := insertRule(n, family, hook[0], hook[1]); err != nil {
				return err
			}
		}
	}
	return nil
}
// DelHooks removes the jump rules into the ts-* chains from filter/INPUT,
// filter/FORWARD and nat/POSTROUTING, first for the IPv4 rule set and then
// for the IPv6 one. It returns the first deletion error.
//
// logf is accepted but not used by this fake.
func (n *fakeIPTablesRunner) DelHooks(logf logger.Logf) error {
	// Each entry is {chain, rule}.
	hooks := [...][2]string{
		{"filter/INPUT", "-j ts-input"},
		{"filter/FORWARD", "-j ts-forward"},
		{"nat/POSTROUTING", "-j ts-postrouting"},
	}
	for _, family := range [...]map[string][]string{n.ipt4, n.ipt6} {
		for _, hook := range hooks {
			if err := deleteRule(n, family, hook[0], hook[1]); err != nil {
				return err
			}
		}
	}
	return nil
}
// AddChains registers filter/ts-input, filter/ts-forward and
// nat/ts-postrouting with empty rule lists in both the IPv4 and IPv6 rule
// sets. Any rules already recorded under those keys are discarded. It always
// returns nil.
func (n *fakeIPTablesRunner) AddChains() error {
	tsChains := [...]string{"filter/ts-input", "filter/ts-forward", "nat/ts-postrouting"}
	for _, family := range [...]map[string][]string{n.ipt4, n.ipt6} {
		for i := range tsChains {
			family[tsChains[i]] = nil
		}
	}
	return nil
}
// DelChains unregisters every chain whose key starts with "filter/ts-" or
// "nat/ts-" from both the IPv4 and IPv6 rule sets, along with any rules
// recorded in them. All other chains are left alone. It always returns nil.
func (n *fakeIPTablesRunner) DelChains() error {
	isTSChain := func(key string) bool {
		return strings.HasPrefix(key, "filter/ts-") || strings.HasPrefix(key, "nat/ts-")
	}
	for _, family := range [...]map[string][]string{n.ipt4, n.ipt6} {
		for key := range family {
			if !isTSChain(key) {
				continue
			}
			delete(family, key)
		}
	}
	return nil
}
// DelBase empties filter/ts-input, filter/ts-forward and nat/ts-postrouting
// in both the IPv4 and IPv6 rule sets by resetting each key to a nil rule
// list. The chains themselves stay registered. It always returns nil.
func (n *fakeIPTablesRunner) DelBase() error {
	tsChains := [...]string{"filter/ts-input", "filter/ts-forward", "nat/ts-postrouting"}
	for _, family := range [...]map[string][]string{n.ipt4, n.ipt6} {
		for i := range tsChains {
			family[tsChains[i]] = nil
		}
	}
	return nil
}
// AddSNATRule appends the MASQUERADE rule for packets carrying the subnet
// route mark to nat/ts-postrouting, first in the IPv4 rule set and then in
// the IPv6 one. It returns the first append error.
func (n *fakeIPTablesRunner) AddSNATRule() error {
	rule := fmt.Sprintf("-m mark --mark %s/%s -j MASQUERADE", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)
	if err := appendRule(n, n.ipt4, "nat/ts-postrouting", rule); err != nil {
		return err
	}
	return appendRule(n, n.ipt6, "nat/ts-postrouting", rule)
}
// DelSNATRule removes the MASQUERADE rule for packets carrying the subnet
// route mark from nat/ts-postrouting, first in the IPv4 rule set and then in
// the IPv6 one. It returns the first deletion error.
func (n *fakeIPTablesRunner) DelSNATRule() error {
	rule := fmt.Sprintf("-m mark --mark %s/%s -j MASQUERADE", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)
	if err := deleteRule(n, n.ipt4, "nat/ts-postrouting", rule); err != nil {
		return err
	}
	return deleteRule(n, n.ipt6, "nat/ts-postrouting", rule)
}
// HasIPV6 reports whether IPv6 rules are supported. The fake always answers
// true.
func (n *fakeIPTablesRunner) HasIPV6() bool {
	return true
}
// HasIPV6NAT reports whether IPv6 NAT rules are supported. The fake always
// answers true.
func (n *fakeIPTablesRunner) HasIPV6NAT() bool {
	return true
}
// fakeOS implements commandRunner and provides v4 and v6
// netfilterRunners, but captures changes without touching the OS.
type fakeOS struct {
t *testing.T
up bool
ips []string
routes []string
rules []string
netfilter4 *fakeNetfilter
netfilter6 *fakeNetfilter
t *testing.T
up bool
ips []string
routes []string
rules []string
// These tests exercise the router level, so it does not matter whether
// iptables or nftables backs the runner; we chose the simpler one.
nfr netfilterRunner
}
func NewFakeOS(t *testing.T) *fakeOS {
return &fakeOS{
t: t,
netfilter4: newNetfilter(t),
netfilter6: newNetfilter(t),
t: t,
nfr: newIPTablesRunner(t),
}
}
@ -516,23 +635,23 @@ func (o *fakeOS) String() string {
}
var chains []string
for chain := range o.netfilter4.n {
for chain := range o.nfr.(*fakeIPTablesRunner).ipt4 {
chains = append(chains, chain)
}
sort.Strings(chains)
for _, chain := range chains {
for _, rule := range o.netfilter4.n[chain] {
for _, rule := range o.nfr.(*fakeIPTablesRunner).ipt4[chain] {
fmt.Fprintf(&b, "v4/%s %s\n", chain, rule)
}
}
chains = nil
for chain := range o.netfilter6.n {
for chain := range o.nfr.(*fakeIPTablesRunner).ipt6 {
chains = append(chains, chain)
}
sort.Strings(chains)
for _, chain := range chains {
for _, rule := range o.netfilter6.n[chain] {
for _, rule := range o.nfr.(*fakeIPTablesRunner).ipt6[chain] {
fmt.Fprintf(&b, "v6/%s %s\n", chain, rule)
}
}
@ -806,7 +925,7 @@ func TestDebugListRules(t *testing.T) {
}
func TestCheckIPRuleSupportsV6(t *testing.T) {
err := checkIPRuleSupportsV6(t.Logf)
err := linuxfw.CheckIPRuleSupportsV6(t.Logf)
if err != nil && os.Getuid() != 0 {
t.Skipf("skipping, error when not root: %v", err)
}