diff --git a/net/packet/ip4.go b/net/packet/ip4.go index 9203ccc53..0240abaa1 100644 --- a/net/packet/ip4.go +++ b/net/packet/ip4.go @@ -6,47 +6,11 @@ package packet import ( "encoding/binary" - "fmt" + "errors" "inet.af/netaddr" ) -// IP4 is an IPv4 address. -type IP4 uint32 - -// IPFromNetaddr converts a netaddr.IP to an IP4. Panics if !ip.Is4. -func IP4FromNetaddr(ip netaddr.IP) IP4 { - ipbytes := ip.As4() - return IP4(binary.BigEndian.Uint32(ipbytes[:])) -} - -// Netaddr converts ip to a netaddr.IP. -func (ip IP4) Netaddr() netaddr.IP { - return netaddr.IPv4(byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip)) -} - -func (ip IP4) String() string { - return fmt.Sprintf("%d.%d.%d.%d", byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip)) -} - -// IsMulticast returns whether ip is a multicast address. -func (ip IP4) IsMulticast() bool { - return byte(ip>>24)&0xf0 == 0xe0 -} - -// IsLinkLocalUnicast returns whether ip is a link-local unicast -// address. -func (ip IP4) IsLinkLocalUnicast() bool { - return byte(ip>>24) == 169 && byte(ip>>16) == 254 -} - -// IsMostLinkLocalUnicast returns whether ip is a link-local unicast -// address other than the magical "169.254.169.254" address used by -// GCP DNS. -func (ip IP4) IsMostLinkLocalUnicast() bool { - return ip.IsLinkLocalUnicast() && ip != 0xA9FEA9FE -} - // ip4HeaderLength is the length of an IPv4 header with no IP options. const ip4HeaderLength = 20 @@ -54,8 +18,8 @@ const ip4HeaderLength = 20 type IP4Header struct { IPProto IPProto IPID uint16 - SrcIP IP4 - DstIP IP4 + Src netaddr.IP + Dst netaddr.IP } // Len implements Header. @@ -63,6 +27,8 @@ func (h IP4Header) Len() int { return ip4HeaderLength } +var errWrongFamily = errors.New("wrong address family for src/dst IP") + // Marshal implements Header. func (h IP4Header) Marshal(buf []byte) error { if len(buf) < h.Len() { @@ -71,6 +37,9 @@ func (h IP4Header) Marshal(buf []byte) error { if len(buf) > maxPacketLength { return errLargePacket } + if !h.Src.Is4() || !h.Dst.Is4() { + return errWrongFamily + } buf[0] = 0x40 | (byte(h.Len() >> 2)) // IPv4 + IHL buf[1] = 0x00 // DSCP + ECN @@ -83,8 +52,10 @@ func (h IP4Header) Marshal(buf []byte) error { // it later, because the checksum computation runs over these // bytes and expects them to be zero. binary.BigEndian.PutUint16(buf[10:12], 0) - binary.BigEndian.PutUint32(buf[12:16], uint32(h.SrcIP)) // Src - binary.BigEndian.PutUint32(buf[16:20], uint32(h.DstIP)) // Dst + src := h.Src.As4() + dst := h.Dst.As4() + copy(buf[12:16], src[:]) + copy(buf[16:20], dst[:]) binary.BigEndian.PutUint16(buf[10:12], ip4Checksum(buf[0:20])) // Checksum @@ -93,7 +64,7 @@ func (h IP4Header) Marshal(buf []byte) error { // ToResponse implements Header. func (h *IP4Header) ToResponse() { - h.SrcIP, h.DstIP = h.DstIP, h.SrcIP + h.Src, h.Dst = h.Dst, h.Src // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. h.IPID = ^h.IPID } @@ -135,8 +106,9 @@ func (h IP4Header) marshalPseudo(buf []byte) error { } length := len(buf) - h.Len() - binary.BigEndian.PutUint32(buf[8:12], uint32(h.SrcIP)) - binary.BigEndian.PutUint32(buf[12:16], uint32(h.DstIP)) + src, dst := h.Src.As4(), h.Dst.As4() + copy(buf[8:12], src[:]) + copy(buf[12:16], dst[:]) buf[16] = 0x0 buf[17] = uint8(h.IPProto) binary.BigEndian.PutUint16(buf[18:20], uint16(length)) diff --git a/net/packet/ip6.go b/net/packet/ip6.go index 8fd964c21..59f605b32 100644 --- a/net/packet/ip6.go +++ b/net/packet/ip6.go @@ -6,49 +6,10 @@ package packet import ( "encoding/binary" - "fmt" "inet.af/netaddr" ) -// IP6 is an IPv6 address. -type IP6 struct { - Hi, Lo uint64 -} - -// IP6FromRaw16 converts a raw 16-byte IPv6 address to an IP6. -func IP6FromRaw16(ip [16]byte) IP6 { - return IP6{binary.BigEndian.Uint64(ip[:8]), binary.BigEndian.Uint64(ip[8:])} -} - -// IP6FromNetaddr converts a netaddr.IP to an IP6. Panics if !ip.Is6. -func IP6FromNetaddr(ip netaddr.IP) IP6 { - if !ip.Is6() { - panic(fmt.Sprintf("IP6FromNetaddr called with non-v6 addr %q", ip)) - } - return IP6FromRaw16(ip.As16()) -} - -// Netaddr converts ip to a netaddr.IP. -func (ip IP6) Netaddr() netaddr.IP { - var b [16]byte - binary.BigEndian.PutUint64(b[:8], ip.Hi) - binary.BigEndian.PutUint64(b[8:], ip.Lo) - return netaddr.IPFrom16(b) -} - -func (ip IP6) String() string { - return ip.Netaddr().String() -} - -func (ip IP6) IsMulticast() bool { - return (ip.Hi >> 56) == 0xFF -} - -func (ip IP6) IsLinkLocalUnicast() bool { - return (ip.Hi >> 48) == 0xFE80 -} - // ip6HeaderLength is the length of an IPv6 header with no IP options. const ip6HeaderLength = 40 @@ -56,8 +17,8 @@ const ip6HeaderLength = 40 type IP6Header struct { IPProto IPProto IPID uint32 // only lower 20 bits used - SrcIP IP6 - DstIP IP6 + Src netaddr.IP + Dst netaddr.IP } // Len implements Header. @@ -79,17 +40,16 @@ func (h IP6Header) Marshal(buf []byte) error { binary.BigEndian.PutUint16(buf[4:6], uint16(len(buf)-ip6HeaderLength)) // Total length buf[6] = uint8(h.IPProto) // Inner protocol buf[7] = 64 // TTL - binary.BigEndian.PutUint64(buf[8:16], h.SrcIP.Hi) - binary.BigEndian.PutUint64(buf[16:24], h.SrcIP.Lo) - binary.BigEndian.PutUint64(buf[24:32], h.DstIP.Hi) - binary.BigEndian.PutUint64(buf[32:40], h.DstIP.Lo) + src, dst := h.Src.As16(), h.Dst.As16() + copy(buf[8:24], src[:]) + copy(buf[24:40], dst[:]) return nil } // ToResponse implements Header. func (h *IP6Header) ToResponse() { - h.SrcIP, h.DstIP = h.DstIP, h.SrcIP + h.Src, h.Dst = h.Dst, h.Src // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. h.IPID = (^h.IPID) & 0x000FFFFF } @@ -104,10 +64,9 @@ func (h IP6Header) marshalPseudo(buf []byte) error { return errLargePacket } - binary.BigEndian.PutUint64(buf[:8], h.SrcIP.Hi) - binary.BigEndian.PutUint64(buf[8:16], h.SrcIP.Lo) - binary.BigEndian.PutUint64(buf[16:24], h.DstIP.Hi) - binary.BigEndian.PutUint64(buf[24:32], h.DstIP.Lo) + src, dst := h.Src.As16(), h.Dst.As16() + copy(buf[:16], src[:]) + copy(buf[16:32], dst[:]) binary.BigEndian.PutUint32(buf[32:36], uint32(len(buf)-h.Len())) buf[36] = 0 buf[37] = 0 diff --git a/net/packet/packet.go b/net/packet/packet.go index f4cb1188c..5502d1959 100644 --- a/net/packet/packet.go +++ b/net/packet/packet.go @@ -7,8 +7,10 @@ package packet import ( "encoding/binary" "fmt" + "net" "strings" + "inet.af/netaddr" "tailscale.com/types/strbuilder" ) @@ -38,64 +40,50 @@ type Parsed struct { IPVersion uint8 // IPProto is the IP subprotocol (UDP, TCP, etc.). Valid iff IPVersion != 0. IPProto IPProto - // SrcIP4 is the IPv4 source address. Valid iff IPVersion == 4. - SrcIP4 IP4 - // DstIP4 is the IPv4 destination address. Valid iff IPVersion == 4. - DstIP4 IP4 - // SrcIP6 is the IPv6 source address. Valid iff IPVersion == 6. - SrcIP6 IP6 - // DstIP6 is the IPv6 destination address. Valid iff IPVersion == 6. - DstIP6 IP6 - // SrcPort is the TCP/UDP source port. Valid iff IPProto == TCP || IPProto == UDP. - SrcPort uint16 - // DstPort is the TCP/UDP source port. Valid iff IPProto == TCP || IPProto == UDP. - DstPort uint16 + // SrcIP4 is the source address. Family matches IPVersion. Port is + // valid iff IPProto == TCP || IPProto == UDP. + Src netaddr.IPPort + // DstIP4 is the destination address. Family matches IPVersion. + Dst netaddr.IPPort // TCPFlags is the packet's TCP flag bigs. Valid iff IPProto == TCP. TCPFlags uint8 } func (p *Parsed) String() string { - switch p.IPVersion { - case 4: - sb := strbuilder.Get() - sb.WriteString(p.IPProto.String()) - sb.WriteByte('{') - writeIP4Port(sb, p.SrcIP4, p.SrcPort) - sb.WriteString(" > ") - writeIP4Port(sb, p.DstIP4, p.DstPort) - sb.WriteByte('}') - return sb.String() - case 6: - sb := strbuilder.Get() - sb.WriteString(p.IPProto.String()) - sb.WriteByte('{') - writeIP6Port(sb, p.SrcIP6, p.SrcPort) - sb.WriteString(" > ") - writeIP6Port(sb, p.DstIP6, p.DstPort) - sb.WriteByte('}') - return sb.String() - default: + if p.IPVersion != 4 && p.IPVersion != 6 { return "Unknown{???}" } + + sb := strbuilder.Get() + sb.WriteString(p.IPProto.String()) + sb.WriteByte('{') + writeIPPort(sb, p.Src) + sb.WriteString(" > ") + writeIPPort(sb, p.Dst) + sb.WriteByte('}') + return sb.String() } -func writeIP4Port(sb *strbuilder.Builder, ip IP4, port uint16) { - sb.WriteUint(uint64(byte(ip >> 24))) - sb.WriteByte('.') - sb.WriteUint(uint64(byte(ip >> 16))) - sb.WriteByte('.') - sb.WriteUint(uint64(byte(ip >> 8))) - sb.WriteByte('.') - sb.WriteUint(uint64(byte(ip))) - sb.WriteByte(':') - sb.WriteUint(uint64(port)) -} - -func writeIP6Port(sb *strbuilder.Builder, ip IP6, port uint16) { - sb.WriteByte('[') - sb.WriteString(ip.Netaddr().String()) // TODO: faster? - sb.WriteString("]:") - sb.WriteUint(uint64(port)) +// writeIPPort writes ipp.String() into sb, with fewer allocations. +// +// TODO: make netaddr more efficient in this area, and retire this func. +func writeIPPort(sb *strbuilder.Builder, ipp netaddr.IPPort) { + if ipp.IP.Is4() { + raw := ipp.IP.As4() + sb.WriteUint(uint64(raw[0])) + sb.WriteByte('.') + sb.WriteUint(uint64(raw[1])) + sb.WriteByte('.') + sb.WriteUint(uint64(raw[2])) + sb.WriteByte('.') + sb.WriteUint(uint64(raw[3])) + sb.WriteByte(':') + } else { + sb.WriteByte('[') + sb.WriteString(ipp.IP.String()) // TODO: faster? + sb.WriteString("]:") + } + sb.WriteUint(uint64(ipp.Port)) } // Decode extracts data from the packet in b into q. @@ -140,8 +128,8 @@ func (q *Parsed) decode4(b []byte) { } // If it's valid IPv4, then the IP addresses are valid - q.SrcIP4 = IP4(binary.BigEndian.Uint32(b[12:16])) - q.DstIP4 = IP4(binary.BigEndian.Uint32(b[16:20])) + q.Src.IP = netaddr.IPv4(b[12], b[13], b[14], b[15]) + q.Dst.IP = netaddr.IPv4(b[16], b[17], b[18], b[19]) q.subofs = int((b[0] & 0x0F) << 2) if q.subofs > q.length { @@ -183,8 +171,8 @@ func (q *Parsed) decode4(b []byte) { q.IPProto = Unknown return } - q.SrcPort = 0 - q.DstPort = 0 + q.Src.Port = 0 + q.Dst.Port = 0 q.dataofs = q.subofs + icmp4HeaderLength return case IGMP: @@ -196,8 +184,8 @@ func (q *Parsed) decode4(b []byte) { q.IPProto = Unknown return } - q.SrcPort = binary.BigEndian.Uint16(sub[0:2]) - q.DstPort = binary.BigEndian.Uint16(sub[2:4]) + q.Src.Port = binary.BigEndian.Uint16(sub[0:2]) + q.Dst.Port = binary.BigEndian.Uint16(sub[2:4]) q.TCPFlags = sub[13] & 0x3F headerLength := (sub[12] & 0xF0) >> 2 q.dataofs = q.subofs + int(headerLength) @@ -207,8 +195,8 @@ func (q *Parsed) decode4(b []byte) { q.IPProto = Unknown return } - q.SrcPort = binary.BigEndian.Uint16(sub[0:2]) - q.DstPort = binary.BigEndian.Uint16(sub[2:4]) + q.Src.Port = binary.BigEndian.Uint16(sub[0:2]) + q.Dst.Port = binary.BigEndian.Uint16(sub[2:4]) q.dataofs = q.subofs + udpHeaderLength return default: @@ -249,10 +237,10 @@ func (q *Parsed) decode6(b []byte) { return } - q.SrcIP6.Hi = binary.BigEndian.Uint64(b[8:16]) - q.SrcIP6.Lo = binary.BigEndian.Uint64(b[16:24]) - q.DstIP6.Hi = binary.BigEndian.Uint64(b[24:32]) - q.DstIP6.Lo = binary.BigEndian.Uint64(b[32:40]) + // okay to ignore `ok` here, because IPs pulled from packets are + // always well-formed stdlib IPs. + q.Src.IP, _ = netaddr.FromStdIP(net.IP(b[8:24])) + q.Dst.IP, _ = netaddr.FromStdIP(net.IP(b[24:40])) // We don't support any IPv6 extension headers. Don't try to // be clever. Therefore, the IP subprotocol always starts at @@ -276,16 +264,16 @@ func (q *Parsed) decode6(b []byte) { q.IPProto = Unknown return } - q.SrcPort = 0 - q.DstPort = 0 + q.Src.Port = 0 + q.Dst.Port = 0 q.dataofs = q.subofs + icmp6HeaderLength case TCP: if len(sub) < tcpHeaderLength { q.IPProto = Unknown return } - q.SrcPort = binary.BigEndian.Uint16(sub[0:2]) - q.DstPort = binary.BigEndian.Uint16(sub[2:4]) + q.Src.Port = binary.BigEndian.Uint16(sub[0:2]) + q.Dst.Port = binary.BigEndian.Uint16(sub[2:4]) q.TCPFlags = sub[13] & 0x3F headerLength := (sub[12] & 0xF0) >> 2 q.dataofs = q.subofs + int(headerLength) @@ -295,8 +283,8 @@ func (q *Parsed) decode6(b []byte) { q.IPProto = Unknown return } - q.SrcPort = binary.BigEndian.Uint16(sub[0:2]) - q.DstPort = binary.BigEndian.Uint16(sub[2:4]) + q.Src.Port = binary.BigEndian.Uint16(sub[0:2]) + q.Dst.Port = binary.BigEndian.Uint16(sub[2:4]) q.dataofs = q.subofs + udpHeaderLength default: q.IPProto = Unknown @@ -312,8 +300,8 @@ func (q *Parsed) IP4Header() IP4Header { return IP4Header{ IPID: ipid, IPProto: q.IPProto, - SrcIP: q.SrcIP4, - DstIP: q.DstIP4, + Src: q.Src.IP, + Dst: q.Dst.IP, } } @@ -334,8 +322,8 @@ func (q *Parsed) UDP4Header() UDP4Header { } return UDP4Header{ IP4Header: q.IP4Header(), - SrcPort: q.SrcPort, - DstPort: q.DstPort, + SrcPort: q.Src.Port, + DstPort: q.Dst.Port, } } diff --git a/net/packet/packet_test.go b/net/packet/packet_test.go index b92c26678..21615e519 100644 --- a/net/packet/packet_test.go +++ b/net/packet/packet_test.go @@ -12,54 +12,12 @@ import ( "inet.af/netaddr" ) -func mustIP4(s string) IP4 { - ip, err := netaddr.ParseIP(s) +func mustIPPort(s string) netaddr.IPPort { + ipp, err := netaddr.ParseIPPort(s) if err != nil { panic(err) } - return IP4FromNetaddr(ip) -} - -func mustIP6(s string) IP6 { - ip, err := netaddr.ParseIP(s) - if err != nil { - panic(err) - } - return IP6FromNetaddr(ip) -} - -func TestIP4String(t *testing.T) { - const str = "1.2.3.4" - ip := mustIP4(str) - - var got string - allocs := testing.AllocsPerRun(1000, func() { - got = ip.String() - }) - - if got != str { - t.Errorf("got %q; want %q", got, str) - } - if allocs != 1 { - t.Errorf("allocs = %v; want 1", allocs) - } -} - -func TestIP6String(t *testing.T) { - const str = "2607:f8b0:400a:809::200e" - ip := mustIP6(str) - - var got string - allocs := testing.AllocsPerRun(1000, func() { - got = ip.String() - }) - - if got != str { - t.Errorf("got %q; want %q", got, str) - } - if allocs != 1 { - t.Errorf("allocs = %v; want 1", allocs) - } + return ipp } var icmp4RequestBuffer = []byte{ @@ -83,10 +41,8 @@ var icmp4RequestDecode = Parsed{ IPVersion: 4, IPProto: ICMPv4, - SrcIP4: mustIP4("1.2.3.4"), - DstIP4: mustIP4("5.6.7.8"), - SrcPort: 0, - DstPort: 0, + Src: mustIPPort("1.2.3.4:0"), + Dst: mustIPPort("5.6.7.8:0"), } var icmp4ReplyBuffer = []byte{ @@ -109,10 +65,8 @@ var icmp4ReplyDecode = Parsed{ IPVersion: 4, IPProto: ICMPv4, - SrcIP4: mustIP4("1.2.3.4"), - DstIP4: mustIP4("5.6.7.8"), - SrcPort: 0, - DstPort: 0, + Src: mustIPPort("1.2.3.4:0"), + Dst: mustIPPort("5.6.7.8:0"), } // ICMPv6 Router Solicitation @@ -132,8 +86,8 @@ var icmp6PacketDecode = Parsed{ length: len(icmp6PacketBuffer), IPVersion: 6, IPProto: ICMPv6, - SrcIP6: mustIP6("fe80::fb57:1dea:9c39:8fb7"), - DstIP6: mustIP6("ff02::2"), + Src: mustIPPort("[fe80::fb57:1dea:9c39:8fb7]:0"), + Dst: mustIPPort("[ff02::2]:0"), } // This is a malformed IPv4 packet. @@ -170,10 +124,8 @@ var tcp4PacketDecode = Parsed{ IPVersion: 4, IPProto: TCP, - SrcIP4: mustIP4("1.2.3.4"), - DstIP4: mustIP4("5.6.7.8"), - SrcPort: 123, - DstPort: 567, + Src: mustIPPort("1.2.3.4:123"), + Dst: mustIPPort("5.6.7.8:567"), TCPFlags: TCPSynAck, } @@ -198,10 +150,8 @@ var tcp6RequestDecode = Parsed{ IPVersion: 6, IPProto: TCP, - SrcIP6: mustIP6("2001:559:bc13:5400:1749:4628:3934:e1b"), - DstIP6: mustIP6("2607:f8b0:400a:809::200e"), - SrcPort: 42080, - DstPort: 80, + Src: mustIPPort("[2001:559:bc13:5400:1749:4628:3934:e1b]:42080"), + Dst: mustIPPort("[2607:f8b0:400a:809::200e]:80"), TCPFlags: TCPSyn, } @@ -226,10 +176,8 @@ var udp4RequestDecode = Parsed{ IPVersion: 4, IPProto: UDP, - SrcIP4: mustIP4("1.2.3.4"), - DstIP4: mustIP4("5.6.7.8"), - SrcPort: 123, - DstPort: 567, + Src: mustIPPort("1.2.3.4:123"), + Dst: mustIPPort("5.6.7.8:567"), } var invalid4RequestBuffer = []byte{ @@ -250,8 +198,8 @@ var invalid4RequestDecode = Parsed{ IPVersion: 4, IPProto: Unknown, - SrcIP4: mustIP4("1.2.3.4"), - DstIP4: mustIP4("5.6.7.8"), + Src: mustIPPort("1.2.3.4:0"), + Dst: mustIPPort("5.6.7.8:0"), } var udp6RequestBuffer = []byte{ @@ -275,10 +223,8 @@ var udp6RequestDecode = Parsed{ IPVersion: 6, IPProto: UDP, - SrcIP6: mustIP6("2001:559:bc13:5400:1749:4628:3934:e1b"), - DstIP6: mustIP6("2607:f8b0:400a:809::200e"), - SrcPort: 54276, - DstPort: 443, + Src: mustIPPort("[2001:559:bc13:5400:1749:4628:3934:e1b]:54276"), + Dst: mustIPPort("[2607:f8b0:400a:809::200e]:443"), } var udp4ReplyBuffer = []byte{ @@ -301,10 +247,8 @@ var udp4ReplyDecode = Parsed{ length: len(udp4ReplyBuffer), IPProto: UDP, - SrcIP4: mustIP4("1.2.3.4"), - DstIP4: mustIP4("5.6.7.8"), - SrcPort: 567, - DstPort: 123, + Src: mustIPPort("1.2.3.4:567"), + Dst: mustIPPort("5.6.7.8:123"), } var igmpPacketBuffer = []byte{ @@ -326,8 +270,8 @@ var igmpPacketDecode = Parsed{ IPVersion: 4, IPProto: IGMP, - SrcIP4: mustIP4("192.168.1.82"), - DstIP4: mustIP4("224.0.0.251"), + Src: mustIPPort("192.168.1.82:0"), + Dst: mustIPPort("224.0.0.251:0"), } func TestParsed(t *testing.T) { diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index d363c8b3b..f35578e15 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -25,45 +25,33 @@ type Filter struct { // tailscale must have a destination within local4 or local6, // regardless of the policy filter below. Zero values reject // all incoming traffic. - local4 []net4 - local6 []net6 + local4 []netaddr.IPPrefix + local6 []netaddr.IPPrefix // matches4 and matches6 are lists of match->action rules // applied to all packets arriving over tailscale // tunnels. Matches are checked in order, and processing stops // at the first matching rule. The default policy if no rules // match is to drop the packet. - matches4 matches4 - matches6 matches6 + matches4 matches + matches6 matches // state is the connection tracking state attached to this // filter. It is used to allow incoming traffic that is a response // to an outbound connection that this node made, even if those // incoming packets don't get accepted by matches above. - state4 *filterState - state6 *filterState + state *filterState } -// tuple4 is a 4-tuple of source and destination IPv4 and port. It's -// used as a lookup key in filterState. -type tuple4 struct { - SrcIP packet.IP4 - DstIP packet.IP4 - SrcPort uint16 - DstPort uint16 -} - -// tuple6 is a 4-tuple of source and destination IPv6 and port. It's -// used as a lookup key in filterState. -type tuple6 struct { - SrcIP packet.IP6 - DstIP packet.IP6 - SrcPort uint16 - DstPort uint16 +// tuple is a 4-tuple of source and destination IP and port. It's used +// as a lookup key in filterState. +type tuple struct { + Src netaddr.IPPort + Dst netaddr.IPPort } // filterState is a state cache of past seen packets. type filterState struct { mu sync.Mutex - lru *lru.Cache // of tuple4 or tuple6 + lru *lru.Cache // of tuple } // lruMax is the size of the LRU cache in filterState. @@ -148,30 +136,58 @@ func NewAllowNone(logf logger.Logf) *Filter { // shares state with the previous one, to enable changing rules at // runtime without breaking existing stateful flows. func New(matches []Match, localNets []netaddr.IPPrefix, shareStateWith *Filter, logf logger.Logf) *Filter { - var state4, state6 *filterState + var state *filterState if shareStateWith != nil { - state4 = shareStateWith.state4 - state6 = shareStateWith.state6 + state = shareStateWith.state } else { - state4 = &filterState{ - lru: lru.New(lruMax), - } - state6 = &filterState{ + state = &filterState{ lru: lru.New(lruMax), } } f := &Filter{ logf: logf, - matches4: newMatches4(matches), - matches6: newMatches6(matches), - local4: nets4FromIPPrefixes(localNets), - local6: nets6FromIPPrefixes(localNets), - state4: state4, - state6: state6, + matches4: matchesFamily(matches, netaddr.IP.Is4), + matches6: matchesFamily(matches, netaddr.IP.Is6), + local4: netsFamily(localNets, netaddr.IP.Is4), + local6: netsFamily(localNets, netaddr.IP.Is6), + state: state, } return f } +func netsFamily(nets []netaddr.IPPrefix, keep func(netaddr.IP) bool) []netaddr.IPPrefix { + var ret []netaddr.IPPrefix + for _, net := range nets { + if keep(net.IP) { + ret = append(ret, net) + } + } + return ret +} + +// matchesFamily returns the subset of ms for which keep(srcNet.IP) +// and keep(dstNet.IP) are both true. +func matchesFamily(ms matches, keep func(netaddr.IP) bool) matches { + var ret matches + for _, m := range ms { + var retm Match + for _, src := range m.Srcs { + if keep(src.IP) { + retm.Srcs = append(retm.Srcs, src) + } + } + for _, dst := range m.Dsts { + if keep(dst.Net.IP) { + retm.Dsts = append(retm.Dsts, dst) + } + } + if len(retm.Srcs) > 0 && len(retm.Dsts) > 0 { + ret = append(ret, retm) + } + } + return ret +} + func maybeHexdump(flag RunFlags, b []byte) string { if flag == 0 { return "" @@ -229,19 +245,17 @@ func (f *Filter) CheckTCP(srcIP, dstIP netaddr.IP, dstPort uint16) Response { return Drop case srcIP.Is4(): pkt.IPVersion = 4 - pkt.SrcIP4 = packet.IP4FromNetaddr(srcIP) - pkt.DstIP4 = packet.IP4FromNetaddr(dstIP) case srcIP.Is6(): pkt.IPVersion = 6 - pkt.SrcIP6 = packet.IP6FromNetaddr(srcIP) - pkt.DstIP6 = packet.IP6FromNetaddr(dstIP) default: panic("unreachable") } + pkt.Src.IP = srcIP + pkt.Dst.IP = dstIP pkt.IPProto = packet.TCP pkt.TCPFlags = packet.TCPSyn - pkt.SrcPort = 0 - pkt.DstPort = dstPort + pkt.Src.Port = 0 + pkt.Dst.Port = dstPort return f.RunIn(pkt, 0) } @@ -287,7 +301,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) { // A compromised peer could try to send us packets for // destinations we didn't explicitly advertise. This check is to // prevent that. - if !ip4InList(q.DstIP4, f.local4) { + if !ipInList(q.Dst.IP, f.local4) { return Drop, "destination not allowed" } @@ -320,11 +334,11 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) { return Accept, "tcp ok" } case packet.UDP: - t := tuple4{q.SrcIP4, q.DstIP4, q.SrcPort, q.DstPort} + t := tuple{q.Src, q.Dst} - f.state4.mu.Lock() - _, ok := f.state4.lru.Get(t) - f.state4.mu.Unlock() + f.state.mu.Lock() + _, ok := f.state.lru.Get(t) + f.state.mu.Unlock() if ok { return Accept, "udp cached" @@ -342,7 +356,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) { // A compromised peer could try to send us packets for // destinations we didn't explicitly advertise. This check is to // prevent that. - if !ip6InList(q.DstIP6, f.local6) { + if !ipInList(q.Dst.IP, f.local6) { return Drop, "destination not allowed" } @@ -375,11 +389,11 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) { return Accept, "tcp ok" } case packet.UDP: - t := tuple6{q.SrcIP6, q.DstIP6, q.SrcPort, q.DstPort} + t := tuple{q.Src, q.Dst} - f.state6.mu.Lock() - _, ok := f.state6.lru.Get(t) - f.state6.mu.Unlock() + f.state.mu.Lock() + _, ok := f.state.lru.Get(t) + f.state.mu.Unlock() if ok { return Accept, "udp cached" @@ -399,20 +413,11 @@ func (f *Filter) runOut(q *packet.Parsed) (r Response, why string) { return Accept, "ok out" } - switch q.IPVersion { - case 4: - t := tuple4{q.DstIP4, q.SrcIP4, q.DstPort, q.SrcPort} - var ti interface{} = t // allocate once, rather than twice inside mutex - f.state4.mu.Lock() - f.state4.lru.Add(ti, ti) - f.state4.mu.Unlock() - case 6: - t := tuple6{q.DstIP6, q.SrcIP6, q.DstPort, q.SrcPort} - var ti interface{} = t // allocate once, rather than twice inside mutex - f.state6.mu.Lock() - f.state6.lru.Add(ti, ti) - f.state6.mu.Unlock() - } + t := tuple{q.Dst, q.Src} + var ti interface{} = t // allocate once, rather than twice inside mutex + f.state.mu.Lock() + f.state.lru.Add(ti, ti) + f.state.mu.Unlock() return Accept, "ok out" } @@ -436,6 +441,8 @@ func (d direction) String() string { } } +var gcpDNSAddr = netaddr.IPv4(169, 254, 169, 254) + // pre runs the direction-agnostic filter logic. dir is only used for // logging. func (f *Filter) pre(q *packet.Parsed, rf RunFlags, dir direction) Response { @@ -448,25 +455,13 @@ func (f *Filter) pre(q *packet.Parsed, rf RunFlags, dir direction) Response { return Drop } - switch q.IPVersion { - case 4: - if q.DstIP4.IsMulticast() { - f.logRateLimit(rf, q, dir, Drop, "multicast") - return Drop - } - if q.DstIP4.IsMostLinkLocalUnicast() { - f.logRateLimit(rf, q, dir, Drop, "link-local-unicast") - return Drop - } - case 6: - if q.DstIP6.IsMulticast() { - f.logRateLimit(rf, q, dir, Drop, "multicast") - return Drop - } - if q.DstIP6.IsLinkLocalUnicast() { - f.logRateLimit(rf, q, dir, Drop, "link-local-unicast") - return Drop - } + if q.Dst.IP.IsMulticast() { + f.logRateLimit(rf, q, dir, Drop, "multicast") + return Drop + } + if q.Dst.IP.IsLinkLocalUnicast() && q.Dst.IP != gcpDNSAddr { + f.logRateLimit(rf, q, dir, Drop, "link-local-unicast") + return Drop } switch q.IPProto { @@ -493,12 +488,5 @@ func omitDropLogging(p *packet.Parsed, dir direction) bool { return false } - switch p.IPVersion { - case 4: - return p.DstIP4.IsMulticast() || p.DstIP4.IsMostLinkLocalUnicast() || p.IPProto == packet.IGMP - case 6: - return p.DstIP6.IsMulticast() || p.DstIP6.IsLinkLocalUnicast() - default: - return false - } + return p.Dst.IP.IsMulticast() || (p.Dst.IP.IsLinkLocalUnicast() && p.Dst.IP != gcpDNSAddr) || p.IPProto == packet.IGMP } diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index dffcf5f0b..5d0e909a5 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -94,9 +94,9 @@ func TestFilter(t *testing.T) { if test.p.IPProto == packet.TCP { var got Response if test.p.IPVersion == 4 { - got = acl.CheckTCP(test.p.SrcIP4.Netaddr(), test.p.DstIP4.Netaddr(), test.p.DstPort) + got = acl.CheckTCP(test.p.Src.IP, test.p.Dst.IP, test.p.Dst.Port) } else { - got = acl.CheckTCP(test.p.SrcIP6.Netaddr(), test.p.DstIP6.Netaddr(), test.p.DstPort) + got = acl.CheckTCP(test.p.Src.IP, test.p.Dst.IP, test.p.Dst.Port) } if test.want != got { t.Errorf("#%d CheckTCP got=%v want=%v packet:%v", i, got, test.want, test.p) @@ -345,19 +345,19 @@ func TestOmitDropLogging(t *testing.T) { }, { name: "v4_multicast_out_low", - pkt: &packet.Parsed{IPVersion: 4, DstIP4: mustIP4("224.0.0.0")}, + pkt: &packet.Parsed{IPVersion: 4, Dst: mustIPPort("224.0.0.0:0")}, dir: out, want: true, }, { name: "v4_multicast_out_high", - pkt: &packet.Parsed{IPVersion: 4, DstIP4: mustIP4("239.255.255.255")}, + pkt: &packet.Parsed{IPVersion: 4, Dst: mustIPPort("239.255.255.255:0")}, dir: out, want: true, }, { name: "v4_link_local_unicast", - pkt: &packet.Parsed{IPVersion: 4, DstIP4: mustIP4("169.254.1.2")}, + pkt: &packet.Parsed{IPVersion: 4, Dst: mustIPPort("169.254.1.2:0")}, dir: out, want: true, }, @@ -387,18 +387,16 @@ func parsed(proto packet.IPProto, src, dst string, sport, dport uint16) packet.P var ret packet.Parsed ret.Decode(dummyPacket) ret.IPProto = proto - ret.SrcPort = sport - ret.DstPort = dport + ret.Src.IP = sip + ret.Src.Port = sport + ret.Dst.IP = dip + ret.Dst.Port = dport ret.TCPFlags = packet.TCPSyn if sip.Is4() { ret.IPVersion = 4 - ret.SrcIP4 = packet.IP4FromNetaddr(sip) - ret.DstIP4 = packet.IP4FromNetaddr(dip) } else { ret.IPVersion = 6 - ret.SrcIP6 = packet.IP6FromNetaddr(sip) - ret.DstIP6 = packet.IP6FromNetaddr(dip) } return ret @@ -407,8 +405,8 @@ func parsed(proto packet.IPProto, src, dst string, sport, dport uint16) packet.P func raw6(proto packet.IPProto, src, dst string, sport, dport uint16, trimLen int) []byte { u := packet.UDP6Header{ IP6Header: packet.IP6Header{ - SrcIP: packet.IP6FromNetaddr(mustIP(src)), - DstIP: packet.IP6FromNetaddr(mustIP(dst)), + Src: mustIP(src), + Dst: mustIP(dst), }, SrcPort: sport, DstPort: dport, @@ -436,8 +434,8 @@ func raw6(proto packet.IPProto, src, dst string, sport, dport uint16, trimLen in func raw4(proto packet.IPProto, src, dst string, sport, dport uint16, trimLength int) []byte { u := packet.UDP4Header{ IP4Header: packet.IP4Header{ - SrcIP: packet.IP4FromNetaddr(mustIP(src)), - DstIP: packet.IP4FromNetaddr(mustIP(dst)), + Src: mustIP(src), + Dst: mustIP(dst), }, SrcPort: sport, DstPort: dport, @@ -488,12 +486,12 @@ func parseHexPkt(t *testing.T, h string) *packet.Parsed { return p } -func mustIP4(s string) packet.IP4 { - ip, err := netaddr.ParseIP(s) +func mustIPPort(s string) netaddr.IPPort { + ipp, err := netaddr.ParseIPPort(s) if err != nil { panic(err) } - return packet.IP4FromNetaddr(ip) + return ipp } func pfx(strs ...string) (ret []netaddr.IPPrefix) { diff --git a/wgengine/filter/match.go b/wgengine/filter/match.go index 27a976ab6..6dbe6b5bd 100644 --- a/wgengine/filter/match.go +++ b/wgengine/filter/match.go @@ -9,6 +9,7 @@ import ( "strings" "inet.af/netaddr" + "tailscale.com/net/packet" ) // PortRange is a range of TCP and UDP ports. @@ -71,3 +72,46 @@ func (m Match) String() string { } return fmt.Sprintf("%v=>%v", ss, ds) } + +type matches []Match + +func (ms matches) match(q *packet.Parsed) bool { + for _, m := range ms { + if !ipInList(q.Src.IP, m.Srcs) { + continue + } + for _, dst := range m.Dsts { + if !dst.Net.Contains(q.Dst.IP) { + continue + } + if !dst.Ports.contains(q.Dst.Port) { + continue + } + return true + } + } + return false +} + +func (ms matches) matchIPsOnly(q *packet.Parsed) bool { + for _, m := range ms { + if !ipInList(q.Src.IP, m.Srcs) { + continue + } + for _, dst := range m.Dsts { + if dst.Net.Contains(q.Dst.IP) { + return true + } + } + } + return false +} + +func ipInList(ip netaddr.IP, netlist []netaddr.IPPrefix) bool { + for _, net := range netlist { + if net.Contains(ip) { + return true + } + } + return false +} diff --git a/wgengine/filter/match4.go b/wgengine/filter/match4.go deleted file mode 100644 index 6e301ae80..000000000 --- a/wgengine/filter/match4.go +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package filter - -import ( - "fmt" - "math/bits" - "strings" - - "inet.af/netaddr" - "tailscale.com/net/packet" -) - -type net4 struct { - ip packet.IP4 - mask packet.IP4 -} - -func net4FromIPPrefix(pfx netaddr.IPPrefix) net4 { - if !pfx.IP.Is4() { - panic("net4FromIPPrefix given non-ipv4 prefix") - } - return net4{ - ip: packet.IP4FromNetaddr(pfx.IP), - mask: netmask4(pfx.Bits), - } -} - -func nets4FromIPPrefixes(pfxs []netaddr.IPPrefix) (ret []net4) { - for _, pfx := range pfxs { - if pfx.IP.Is4() { - ret = append(ret, net4FromIPPrefix(pfx)) - } - } - return ret -} - -func (n net4) Contains(ip packet.IP4) bool { - return (n.ip & n.mask) == (ip & n.mask) -} - -func (n net4) Bits() int { - return 32 - bits.TrailingZeros32(uint32(n.mask)) -} - -func (n net4) String() string { - b := n.Bits() - if b == 32 { - return n.ip.String() - } else if b == 0 { - return "*" - } else { - return fmt.Sprintf("%s/%d", n.ip, b) - } -} - -type npr4 struct { - net net4 - ports PortRange -} - -func (npr npr4) String() string { - return fmt.Sprintf("%s:%s", npr.net, npr.ports) -} - -type match4 struct { - srcs []net4 - dsts []npr4 -} - -type matches4 []match4 - -func (ms matches4) String() string { - var b strings.Builder - for _, m := range ms { - fmt.Fprintf(&b, "%s => %s\n", m.srcs, m.dsts) - } - return b.String() -} - -func newMatches4(ms []Match) (ret matches4) { - for _, m := range ms { - var m4 match4 - for _, src := range m.Srcs { - if src.IP.Is4() { - m4.srcs = append(m4.srcs, net4FromIPPrefix(src)) - } - } - for _, dst := range m.Dsts { - if dst.Net.IP.Is4() { - m4.dsts = append(m4.dsts, npr4{net4FromIPPrefix(dst.Net), dst.Ports}) - } - } - if len(m4.srcs) > 0 && len(m4.dsts) > 0 { - ret = append(ret, m4) - } - } - return ret -} - -// match returns whether q's source IP and destination IP:port match -// any of ms. -func (ms matches4) match(q *packet.Parsed) bool { - for _, m := range ms { - if !ip4InList(q.SrcIP4, m.srcs) { - continue - } - for _, dst := range m.dsts { - if !dst.net.Contains(q.DstIP4) { - continue - } - if !dst.ports.contains(q.DstPort) { - continue - } - return true - } - } - return false -} - -// matchIPsOnly returns whether q's source and destination IP match -// any of ms. -func (ms matches4) matchIPsOnly(q *packet.Parsed) bool { - for _, m := range ms { - if !ip4InList(q.SrcIP4, m.srcs) { - continue - } - for _, dst := range m.dsts { - if dst.net.Contains(q.DstIP4) { - return true - } - } - } - return false -} - -func netmask4(bits uint8) packet.IP4 { - b := ^uint32((1 << (32 - bits)) - 1) - return packet.IP4(b) -} - -func ip4InList(ip packet.IP4, netlist []net4) bool { - for _, net := range netlist { - if net.Contains(ip) { - return true - } - } - return false -} diff --git a/wgengine/filter/match6.go b/wgengine/filter/match6.go deleted file mode 100644 index 1558c83ed..000000000 --- a/wgengine/filter/match6.go +++ /dev/null @@ -1,171 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package filter - -import ( - "fmt" - "math/bits" - "strings" - - "inet.af/netaddr" - "tailscale.com/net/packet" -) - -type net6 struct { - ip packet.IP6 - mask packet.IP6 -} - -func net6FromIPPrefix(pfx netaddr.IPPrefix) net6 { - if !pfx.IP.Is6() { - panic("net6FromIPPrefix given non-ipv6 prefix") - } - var mask packet.IP6 - if pfx.Bits > 64 { - mask.Hi = ^uint64(0) - mask.Lo = (^uint64(0) << (128 - pfx.Bits)) - } else { - mask.Hi = (^uint64(0) << (64 - pfx.Bits)) - } - - return net6{ - ip: packet.IP6FromNetaddr(pfx.IP), - mask: mask, - } -} - -func nets6FromIPPrefixes(pfxs []netaddr.IPPrefix) (ret []net6) { - for _, pfx := range pfxs { - if pfx.IP.Is6() { - ret = append(ret, net6FromIPPrefix(pfx)) - } - } - return ret -} - -func (n net6) Contains(ip packet.IP6) bool { - // This is equivalent to the more straightforward implementation: - // ((n.ip.Hi & n.mask.Hi) == (ip.Hi & n.mask.Hi) && - // (n.ip.Lo & n.mask.Lo) == (ip.Lo & n.mask.Lo)) - // - // This implementation runs significantly faster because it - // eliminates branches and minimizes the required - // bit-twiddling. - a := (n.ip.Hi ^ ip.Hi) & n.mask.Hi - b := (n.ip.Lo ^ ip.Lo) & n.mask.Lo - return (a | b) == 0 -} - -func (n net6) Bits() int { - return 128 - bits.TrailingZeros64(n.mask.Hi) - bits.TrailingZeros64(n.mask.Lo) -} - -func (n net6) String() string { - switch n.Bits() { - case 128: - return n.ip.String() - case 0: - return "*" - default: - return fmt.Sprintf("%s/%d", n.ip, n.Bits()) - } -} - -type npr6 struct { - net net6 - ports PortRange -} - -func (npr npr6) String() string { - return fmt.Sprintf("%s:%s", npr.net, npr.ports) -} - -type match6 struct { - srcs []net6 - dsts []npr6 -} - -type matches6 []match6 - -func (ms matches6) String() string { - var b strings.Builder - for _, m := range ms { - fmt.Fprintf(&b, "%s => %s\n", m.srcs, m.dsts) - } - return b.String() -} - -func newMatches6(ms []Match) (ret matches6) { - for _, m := range ms { - var m6 match6 - for _, src := range m.Srcs { - if src.IP.Is6() { - m6.srcs = append(m6.srcs, net6FromIPPrefix(src)) - } - } - for _, dst := range m.Dsts { - if dst.Net.IP.Is6() { - m6.dsts = append(m6.dsts, npr6{net6FromIPPrefix(dst.Net), dst.Ports}) - } - } - if len(m6.srcs) > 0 && len(m6.dsts) > 0 { - ret = append(ret, m6) - } - } - return ret -} - -func (ms matches6) match(q *packet.Parsed) bool { -outer: - for i := range ms { - srcs := ms[i].srcs - for j := range srcs { - if srcs[j].Contains(q.SrcIP6) { - dsts := ms[i].dsts - for k := range dsts { - if dsts[k].net.Contains(q.DstIP6) && dsts[k].ports.contains(q.DstPort) { - return true - } - } - // We hit on src, but missed on all - // dsts. No need to try other srcs, - // they'll never fully match. - continue outer - } - } - } - return false -} - -func (ms matches6) matchIPsOnly(q *packet.Parsed) bool { -outer: - for i := range ms { - srcs := ms[i].srcs - for j := range srcs { - if srcs[j].Contains(q.SrcIP6) { - dsts := ms[i].dsts - for k := range dsts { - if dsts[k].net.Contains(q.DstIP6) { - return true - } - } - // We hit on src, but missed on all - // dsts. No need to try other srcs, - // they'll never fully match. - continue outer - } - } - } - return false -} - -func ip6InList(ip packet.IP6, netlist []net6) bool { - for _, net := range netlist { - if net.Contains(ip) { - return true - } - } - return false -} diff --git a/wgengine/filter/match6_test.go b/wgengine/filter/match6_test.go deleted file mode 100644 index d96b3794a..000000000 --- a/wgengine/filter/match6_test.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package filter - -import "testing" - -// Verifies that the fast bit-twiddling implementation of Contains -// works the same as the easy-to-read implementation. Since we can't -// sensibly check it on 128 bits, the test runs over 4-bit -// "IPs". Bit-twiddling is the same at any width, so this adequately -// proves that the implementations are equivalent. -func TestOptimizedContains(t *testing.T) { - for ipHi := 0; ipHi < 0xf; ipHi++ { - for ipLo := 0; ipLo < 0xf; ipLo++ { - for nIPHi := 0; nIPHi < 0xf; nIPHi++ { - for nIPLo := 0; nIPLo < 0xf; nIPLo++ { - for maskHi := 0; maskHi < 0xf; maskHi++ { - for maskLo := 0; maskLo < 0xf; maskLo++ { - - a := (nIPHi ^ ipHi) & maskHi - b := (nIPLo ^ ipLo) & maskLo - got := (a | b) == 0 - - want := ((nIPHi&maskHi) == (ipHi&maskHi) && (nIPLo&maskLo) == (ipLo&maskLo)) - - if got != want { - t.Errorf("mask %1x%1x/%1x%1x %1x%1x got=%v want=%v", nIPHi, nIPLo, maskHi, maskLo, ipHi, ipLo, got, want) - } - } - } - } - } - } - } -} diff --git a/wgengine/tstun/tun.go b/wgengine/tstun/tun.go index 193e51444..a1a9a2a3d 100644 --- a/wgengine/tstun/tun.go +++ b/wgengine/tstun/tun.go @@ -16,6 +16,7 @@ import ( "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" + "inet.af/netaddr" "tailscale.com/net/packet" "tailscale.com/types/logger" "tailscale.com/wgengine/filter" @@ -67,8 +68,7 @@ type TUN struct { lastActivityAtomic int64 // unix seconds of last send or receive - destIPActivity4 atomic.Value // of map[packet.IP4]func() - destIPActivity6 atomic.Value // of map[packet.IP6]func() + destIPActivity atomic.Value // of map[netaddr.IP]func() // buffer stores the oldest unconsumed packet from tdev. // It is made a static buffer in order to avoid allocations. @@ -137,9 +137,8 @@ func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN { // destination (the map keys). // // The map ownership passes to the TUN. It must be non-nil. -func (t *TUN) SetDestIPActivityFuncs(m4 map[packet.IP4]func(), m6 map[packet.IP6]func()) { - t.destIPActivity4.Store(m4) - t.destIPActivity6.Store(m6) +func (t *TUN) SetDestIPActivityFuncs(m map[netaddr.IP]func()) { + t.destIPActivity.Store(m) } func (t *TUN) Close() error { @@ -284,18 +283,9 @@ func (t *TUN) Read(buf []byte, offset int) (int, error) { defer parsedPacketPool.Put(p) p.Decode(buf[offset : offset+n]) - switch p.IPVersion { - case 4: - if m, ok := t.destIPActivity4.Load().(map[packet.IP4]func()); ok { - if fn := m[p.DstIP4]; fn != nil { - fn() - } - } - case 6: - if m, ok := t.destIPActivity6.Load().(map[packet.IP6]func()); ok { - if fn := m[p.DstIP6]; fn != nil { - fn() - } + if m, ok := t.destIPActivity.Load().(map[netaddr.IP]func()); ok { + if fn := m[p.Dst.IP]; fn != nil { + fn() } } diff --git a/wgengine/tstun/tun_test.go b/wgengine/tstun/tun_test.go index f6e363f18..9eeeeeef1 100644 --- a/wgengine/tstun/tun_test.go +++ b/wgengine/tstun/tun_test.go @@ -20,12 +20,20 @@ import ( "tailscale.com/wgengine/filter" ) -func udp(src, dst packet.IP4, sport, dport uint16) []byte { +func udp4(src, dst string, sport, dport uint16) []byte { + sip, err := netaddr.ParseIP(src) + if err != nil { + panic(err) + } + dip, err := netaddr.ParseIP(dst) + if err != nil { + panic(err) + } header := &packet.UDP4Header{ IP4Header: packet.IP4Header{ - SrcIP: src, - DstIP: dst, - IPID: 0, + Src: sip, + Dst: dip, + IPID: 0, }, SrcPort: sport, DstPort: dport, @@ -252,12 +260,12 @@ func TestFilter(t *testing.T) { }{ {"junk_in", in, true, []byte("\x45not a valid IPv4 packet")}, {"junk_out", out, true, []byte("\x45not a valid IPv4 packet")}, - {"bad_port_in", in, true, udp(0x05060708, 0x01020304, 22, 22)}, - {"bad_port_out", out, false, udp(0x01020304, 0x05060708, 22, 22)}, - {"bad_ip_in", in, true, udp(0x08010101, 0x01020304, 89, 89)}, - {"bad_ip_out", out, false, udp(0x01020304, 0x08010101, 98, 98)}, - {"good_packet_in", in, false, udp(0x05060708, 0x01020304, 89, 89)}, - {"good_packet_out", out, false, udp(0x01020304, 0x05060708, 98, 98)}, + {"bad_port_in", in, true, udp4("5.6.7.8", "1.2.3.4", 22, 22)}, + {"bad_port_out", out, false, udp4("1.2.3.4", "5.6.7.8", 22, 22)}, + {"bad_ip_in", in, true, udp4("8.1.1.1", "1.2.3.4", 89, 89)}, + {"bad_ip_out", out, false, udp4("1.2.3.4", "8.1.1.1", 98, 98)}, + {"good_packet_in", in, false, udp4("5.6.7.8", "1.2.3.4", 89, 89)}, + {"good_packet_out", out, false, udp4("1.2.3.4", "5.6.7.8", 98, 98)}, } // A reader on the other end of the TUN. @@ -337,7 +345,7 @@ func BenchmarkWrite(b *testing.B) { ftun, tun := newFakeTUN(b.Logf, true) defer tun.Close() - packet := udp(0x05060708, 0x01020304, 89, 89) + packet := udp4("5.6.7.8", "1.2.3.4", 89, 89) for i := 0; i < b.N; i++ { _, err := ftun.Write(packet, 0) if err != nil { diff --git a/wgengine/userspace.go b/wgengine/userspace.go index e44401f2d..9064f524a 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -8,7 +8,6 @@ import ( "bufio" "bytes" "context" - "encoding/binary" "errors" "fmt" "io" @@ -58,10 +57,9 @@ import ( // discovery. const minimalMTU = 1280 -const ( - magicDNSIP = 0x64646464 // 100.100.100.100 - magicDNSPort = 53 -) +const magicDNSPort = 53 + +var magicDNSIP = netaddr.IPv4(100, 100, 100, 100) // Lazy wireguard-go configuration parameters. const ( @@ -99,19 +97,17 @@ type userspaceEngine struct { // localAddrs is the set of IP addresses assigned to the local // tunnel interface. It's used to reflect local packets // incorrectly sent to us. - localAddrs atomic.Value // of map[packet.IP4]bool + localAddrs atomic.Value // of map[netaddr.IP]bool - wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below - lastCfgFull wgcfg.Config - lastRouterSig string // of router.Config - lastEngineSigFull string // of full wireguard config - lastEngineSigTrim string // of trimmed wireguard config - recvActivityAt map[tailcfg.DiscoKey]time.Time - trimmedDisco map[tailcfg.DiscoKey]bool // set of disco keys of peers currently excluded from wireguard config - sentActivityAt4 map[packet.IP4]*int64 // value is atomic int64 of unixtime - destIPActivityFuncs4 map[packet.IP4]func() - sentActivityAt6 map[packet.IP6]*int64 // value is atomic int64 of unixtime - destIPActivityFuncs6 map[packet.IP6]func() + wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below + lastCfgFull wgcfg.Config + lastRouterSig string // of router.Config + lastEngineSigFull string // of full wireguard config + lastEngineSigTrim string // of trimmed wireguard config + recvActivityAt map[tailcfg.DiscoKey]time.Time + trimmedDisco map[tailcfg.DiscoKey]bool // set of disco keys of peers currently excluded from wireguard config + sentActivityAt map[netaddr.IP]*int64 // value is atomic int64 of unixtime + destIPActivityFuncs map[netaddr.IP]func() mu sync.Mutex // guards following; see lock order comment below closing bool // Close was called (even if we're still closing) @@ -208,7 +204,7 @@ func newUserspaceEngineAdvanced(conf EngineConfig) (_ Engine, reterr error) { resolver: tsdns.NewResolver(rconf), pingers: make(map[wgcfg.Key]*pinger), } - e.localAddrs.Store(map[packet.IP4]bool{}) + e.localAddrs.Store(map[netaddr.IP]bool{}) e.linkState, _ = getLinkState() logf("link state: %+v", e.linkState) @@ -399,7 +395,7 @@ func (e *userspaceEngine) handleLocalPackets(p *packet.Parsed, t *tstun.TUN) fil return filter.Drop } - if (runtime.GOOS == "darwin" || runtime.GOOS == "ios") && e.isLocalAddr(p.DstIP4) { + if (runtime.GOOS == "darwin" || runtime.GOOS == "ios") && e.isLocalAddr(p.Dst.IP) { // macOS NetworkExtension directs packets destined to the // tunnel's local IP address into the tunnel, instead of // looping back within the kernel network stack. We have to @@ -412,8 +408,8 @@ func (e *userspaceEngine) handleLocalPackets(p *packet.Parsed, t *tstun.TUN) fil return filter.Accept } -func (e *userspaceEngine) isLocalAddr(ip packet.IP4) bool { - localAddrs, ok := e.localAddrs.Load().(map[packet.IP4]bool) +func (e *userspaceEngine) isLocalAddr(ip netaddr.IP) bool { + localAddrs, ok := e.localAddrs.Load().(map[netaddr.IP]bool) if !ok { e.logf("[unexpected] e.localAddrs was nil, can't check for loopback packet") return false @@ -423,10 +419,10 @@ func (e *userspaceEngine) isLocalAddr(ip packet.IP4) bool { // handleDNS is an outbound pre-filter resolving Tailscale domains. func (e *userspaceEngine) handleDNS(p *packet.Parsed, t *tstun.TUN) filter.Response { - if p.DstIP4 == magicDNSIP && p.DstPort == magicDNSPort && p.IPProto == packet.UDP { + if p.Dst.IP == magicDNSIP && p.Dst.Port == magicDNSPort && p.IPProto == packet.UDP { request := tsdns.Packet{ Payload: append([]byte(nil), p.Payload()...), - Addr: netaddr.IPPort{IP: p.SrcIP4.Netaddr(), Port: p.SrcPort}, + Addr: netaddr.IPPort{IP: p.Src.IP, Port: p.Src.Port}, } err := e.resolver.EnqueueRequest(request) if err != nil { @@ -451,8 +447,8 @@ func (e *userspaceEngine) pollResolver() { h := packet.UDP4Header{ IP4Header: packet.IP4Header{ - SrcIP: packet.IP4(magicDNSIP), - DstIP: packet.IP4FromNetaddr(resp.Addr.IP), + Src: magicDNSIP, + Dst: resp.Addr.IP, }, SrcPort: magicDNSPort, DstPort: resp.Addr.Port, @@ -489,7 +485,7 @@ func (p *pinger) close() { <-p.done } -func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, srcIP packet.IP4) { +func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, srcIP netaddr.IP) { defer func() { p.e.mu.Lock() if p.e.pingers[peerKey] == p { @@ -502,7 +498,7 @@ func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, src header := packet.ICMP4Header{ IP4Header: packet.IP4Header{ - SrcIP: srcIP, + Src: srcIP, }, Type: packet.ICMP4EchoRequest, Code: packet.ICMP4NoCode, @@ -515,7 +511,7 @@ func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, src const stopAfter = 3 * time.Second start := time.Now() - var dstIPs []packet.IP4 + var dstIPs []netaddr.IP for _, ip := range ips { if ip.Is6() { // This code is only used for legacy (pre-discovery) @@ -524,7 +520,7 @@ func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, src // work. continue } - dstIPs = append(dstIPs, packet.IP4FromNetaddr(netaddr.IPFrom16(ip.Addr))) + dstIPs = append(dstIPs, netaddr.IPFrom16(ip.Addr)) } payload := []byte("magicsock_spray") // no meaning @@ -542,7 +538,7 @@ func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, src return } for _, dstIP := range dstIPs { - header.DstIP = dstIP + header.Dst = dstIP // InjectOutbound take ownership of the packet, so we allocate. b := packet.Generate(&header, payload) p.e.tundev.InjectOutbound(b) @@ -560,15 +556,15 @@ func (p *pinger) run(ctx context.Context, peerKey wgcfg.Key, ips []wgcfg.IP, src // have advertised discovery keys. func (e *userspaceEngine) pinger(peerKey wgcfg.Key, ips []wgcfg.IP) { e.logf("[v1] generating initial ping traffic to %s (%v)", peerKey.ShortString(), ips) - var srcIP packet.IP4 + var srcIP netaddr.IP e.wgLock.Lock() if len(e.lastCfgFull.Addresses) > 0 { - srcIP = packet.IP4FromNetaddr(netaddr.IPFrom16(e.lastCfgFull.Addresses[0].IP.Addr)) + srcIP = netaddr.IPFrom16(e.lastCfgFull.Addresses[0].IP.Addr) } e.wgLock.Unlock() - if srcIP == 0 { + if srcIP.IsZero() { e.logf("generating initial ping traffic: no source IP") return } @@ -694,17 +690,8 @@ func (e *userspaceEngine) isActiveSince(dk tailcfg.DiscoKey, ip wgcfg.IP, t time if e.recvActivityAt[dk].After(t) { return true } - var ( - timePtr *int64 - ok bool - ) - if ip.Is4() { - pip := packet.IP4(binary.BigEndian.Uint32(ip.Addr[12:])) - timePtr, ok = e.sentActivityAt4[pip] - } else { - pip := packet.IP6FromRaw16(ip.Addr) - timePtr, ok = e.sentActivityAt6[pip] - } + pip := netaddr.IPFrom16(ip.Addr) + timePtr, ok := e.sentActivityAt[pip] if !ok { return false } @@ -845,14 +832,10 @@ func (e *userspaceEngine) updateActivityMapsLocked(trackDisco []tailcfg.DiscoKey } e.recvActivityAt = mr - oldTime4 := e.sentActivityAt4 - e.sentActivityAt4 = make(map[packet.IP4]*int64, len(oldTime4)) - oldFunc4 := e.destIPActivityFuncs4 - e.destIPActivityFuncs4 = make(map[packet.IP4]func(), len(oldFunc4)) - oldTime6 := e.sentActivityAt6 - e.sentActivityAt6 = make(map[packet.IP6]*int64, len(oldTime6)) - oldFunc6 := e.destIPActivityFuncs6 - e.destIPActivityFuncs6 = make(map[packet.IP6]func(), len(oldFunc6)) + oldTime := e.sentActivityAt + e.sentActivityAt = make(map[netaddr.IP]*int64, len(oldTime)) + oldFunc := e.destIPActivityFuncs + e.destIPActivityFuncs = make(map[netaddr.IP]func(), len(oldFunc)) updateFn := func(timePtr *int64) func() { return func() { @@ -877,35 +860,20 @@ func (e *userspaceEngine) updateActivityMapsLocked(trackDisco []tailcfg.DiscoKey } for _, wip := range trackIPs { - if wip.Is4() { - pip := packet.IP4(binary.BigEndian.Uint32(wip.Addr[12:])) - timePtr := oldTime4[pip] - if timePtr == nil { - timePtr = new(int64) - } - e.sentActivityAt4[pip] = timePtr - - fn := oldFunc4[pip] - if fn == nil { - fn = updateFn(timePtr) - } - e.destIPActivityFuncs4[pip] = fn - } else { - pip := packet.IP6FromRaw16(wip.Addr) - timePtr := oldTime6[pip] - if timePtr == nil { - timePtr = new(int64) - } - e.sentActivityAt6[pip] = timePtr - - fn := oldFunc6[pip] - if fn == nil { - fn = updateFn(timePtr) - } - e.destIPActivityFuncs6[pip] = fn + pip := netaddr.IPFrom16(wip.Addr) + timePtr := oldTime[pip] + if timePtr == nil { + timePtr = new(int64) } + e.sentActivityAt[pip] = timePtr + + fn := oldFunc[pip] + if fn == nil { + fn = updateFn(timePtr) + } + e.destIPActivityFuncs[pip] = fn } - e.tundev.SetDestIPActivityFuncs(e.destIPActivityFuncs4, e.destIPActivityFuncs6) + e.tundev.SetDestIPActivityFuncs(e.destIPActivityFuncs) } func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config) error { @@ -913,13 +881,9 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config) panic("routerCfg must not be nil") } - localAddrs := map[packet.IP4]bool{} + localAddrs := map[netaddr.IP]bool{} for _, addr := range routerCfg.LocalAddrs { - // TODO: ipv6 - if !addr.IP.Is4() { - continue - } - localAddrs[packet.IP4FromNetaddr(addr.IP)] = true + localAddrs[addr.IP] = true } e.localAddrs.Store(localAddrs)