From 21ed31e33a606977112c51601bf282480ba5b784 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sat, 15 Jun 2024 18:20:17 -0700 Subject: [PATCH] wgengine/filter: use NewContainsIPFunc for Srcs matches MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit NewContainsIPFunc returns a contains matcher optimized for its input. Use that instead of what this did before, always doing a test over each of a list of netip.Prefixes. goos: darwin goarch: arm64 pkg: tailscale.com/wgengine/filter │ before │ after │ │ sec/op │ sec/op vs base │ FilterMatch/file1-8 32.60n ± 1% 18.87n ± 1% -42.12% (p=0.000 n=10) Updates #12486 Change-Id: I8f902bc064effb431e5b46751115942104ff6531 Signed-off-by: Brad Fitzpatrick --- wgengine/filter/filter.go | 33 +++++++++++++++++++++++---------- wgengine/filter/filter_clone.go | 9 +++++---- wgengine/filter/filter_test.go | 28 ++++++++++++++++------------ wgengine/filter/match.go | 24 ++++++++---------------- wgengine/filter/tailcfg.go | 3 +++ 5 files changed, 55 insertions(+), 42 deletions(-) diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index 082d8a0f5..0e01c848d 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -16,10 +16,12 @@ import ( "tailscale.com/net/flowtrack" "tailscale.com/net/netaddr" "tailscale.com/net/packet" + "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/tstime/rate" "tailscale.com/types/ipproto" "tailscale.com/types/logger" + "tailscale.com/types/views" "tailscale.com/util/mak" ) @@ -30,12 +32,12 @@ type Filter struct { // this node. All packets coming in over tailscale must have a // destination within local, regardless of the policy filter // below. - local *netipx.IPSet + local func(netip.Addr) bool // logIPs is the set of IPs that are allowed to appear in flow // logs. If a packet is to or from an IP not in logIPs, it will // never be logged. - logIPs *netipx.IPSet + logIPs func(netip.Addr) bool // matches4 and matches6 are lists of match->action rules // applied to all packets arriving over tailscale @@ -172,7 +174,7 @@ func NewShieldsUpFilter(localNets *netipx.IPSet, logIPs *netipx.IPSet, shareStat // by matches. If shareStateWith is non-nil, the returned filter // shares state with the previous one, to enable changing rules at // runtime without breaking existing stateful flows. -func New(matches []Match, localNets *netipx.IPSet, logIPs *netipx.IPSet, shareStateWith *Filter, logf logger.Logf) *Filter { +func New(matches []Match, localNets, logIPs *netipx.IPSet, shareStateWith *Filter, logf logger.Logf) *Filter { var state *filterState if shareStateWith != nil { state = shareStateWith.state @@ -181,14 +183,22 @@ func New(matches []Match, localNets *netipx.IPSet, logIPs *netipx.IPSet, shareSt lru: &flowtrack.Cache[struct{}]{MaxEntries: lruMax}, } } + + containsFunc := func(s *netipx.IPSet) func(netip.Addr) bool { + if s == nil { + return tsaddr.FalseContainsIPFunc() + } + return tsaddr.NewContainsIPFunc(views.SliceOf(s.Prefixes())) + } + f := &Filter{ logf: logf, matches4: matchesFamily(matches, netip.Addr.Is4), matches6: matchesFamily(matches, netip.Addr.Is6), cap4: capMatchesFunc(matches, netip.Addr.Is4), cap6: capMatchesFunc(matches, netip.Addr.Is6), - local: localNets, - logIPs: logIPs, + local: containsFunc(localNets), + logIPs: containsFunc(logIPs), state: state, } return f @@ -206,12 +216,14 @@ func matchesFamily(ms matches, keep func(netip.Addr) bool) matches { retm.Srcs = append(retm.Srcs, src) } } + for _, dst := range m.Dsts { if keep(dst.Net.Addr()) { retm.Dsts = append(retm.Dsts, dst) } } if len(retm.Srcs) > 0 && len(retm.Dsts) > 0 { + retm.SrcsContains = tsaddr.NewContainsIPFunc(views.SliceOf(retm.Srcs)) ret = append(ret, retm) } } @@ -233,6 +245,7 @@ func capMatchesFunc(ms matches, keep func(netip.Addr) bool) matches { } } if len(retm.Srcs) > 0 { + retm.SrcsContains = tsaddr.NewContainsIPFunc(views.SliceOf(retm.Srcs)) ret = append(ret, retm) } } @@ -268,7 +281,7 @@ func init() { } func (f *Filter) logRateLimit(runflags RunFlags, q *packet.Parsed, dir direction, r Response, why string) { - if !f.loggingAllowed(q) { + if runflags == 0 || !f.loggingAllowed(q) { return } @@ -345,7 +358,7 @@ func (f *Filter) CapsWithValues(srcIP, dstIP netip.Addr) tailcfg.PeerCapMap { } var out tailcfg.PeerCapMap for _, m := range mm { - if !ipInList(srcIP, m.Srcs) { + if !m.SrcsContains(srcIP) { continue } for _, cm := range m.Caps { @@ -418,7 +431,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) { // A compromised peer could try to send us packets for // destinations we didn't explicitly advertise. This check is to // prevent that. - if !f.local.Contains(q.Dst.Addr()) { + if !f.local(q.Dst.Addr()) { return Drop, "destination not allowed" } @@ -478,7 +491,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) { // A compromised peer could try to send us packets for // destinations we didn't explicitly advertise. This check is to // prevent that. - if !f.local.Contains(q.Dst.Addr()) { + if !f.local(q.Dst.Addr()) { return Drop, "destination not allowed" } @@ -604,7 +617,7 @@ func (f *Filter) pre(q *packet.Parsed, rf RunFlags, dir direction) Response { // loggingAllowed reports whether p can appear in logs at all. func (f *Filter) loggingAllowed(p *packet.Parsed) bool { - return f.logIPs.Contains(p.Src.Addr()) && f.logIPs.Contains(p.Dst.Addr()) + return f.logIPs(p.Src.Addr()) && f.logIPs(p.Dst.Addr()) } // omitDropLogging reports whether packet p, which has already been diff --git a/wgengine/filter/filter_clone.go b/wgengine/filter/filter_clone.go index 97366d83c..adeaf6efe 100644 --- a/wgengine/filter/filter_clone.go +++ b/wgengine/filter/filter_clone.go @@ -34,10 +34,11 @@ func (src *Match) Clone() *Match { // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _MatchCloneNeedsRegeneration = Match(struct { - IPProto []ipproto.Proto - Srcs []netip.Prefix - Dsts []NetPortRange - Caps []CapMatch + IPProto []ipproto.Proto + Srcs []netip.Prefix + SrcsContains func(netip.Addr) bool + Dsts []NetPortRange + Caps []CapMatch }{}) // Clone makes a deep copy of CapMatch. diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index dc9932db3..39eca5f66 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -16,6 +16,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "go4.org/netipx" xmaps "golang.org/x/exp/maps" "tailscale.com/net/packet" @@ -25,6 +26,7 @@ import ( "tailscale.com/tstime/rate" "tailscale.com/types/ipproto" "tailscale.com/types/logger" + "tailscale.com/types/views" "tailscale.com/util/must" ) @@ -40,9 +42,10 @@ func m(srcs []netip.Prefix, dsts []NetPortRange, protos ...ipproto.Proto) Match protos = defaultProtos } return Match{ - IPProto: protos, - Srcs: srcs, - Dsts: dsts, + IPProto: protos, + Srcs: srcs, + SrcsContains: tsaddr.NewContainsIPFunc(views.SliceOf(srcs)), + Dsts: dsts, } } @@ -436,11 +439,11 @@ func TestLoggingPrivacy(t *testing.T) { logged = true } - var logB netipx.IPSetBuilder - logB.AddPrefix(netip.MustParsePrefix("100.64.0.0/10")) - logB.AddPrefix(tsaddr.TailscaleULARange()) f := newFilter(logf) - f.logIPs, _ = logB.IPSet() + f.logIPs = tsaddr.NewContainsIPFunc(views.SliceOf([]netip.Prefix{ + tsaddr.CGNATRange(), + tsaddr.TailscaleULARange(), + })) var ( ts4 = netip.AddrPortFrom(tsaddr.CGNATRange().Addr().Next(), 1234) @@ -820,11 +823,12 @@ func TestMatchesFromFilterRules(t *testing.T) { if err != nil { t.Fatal(err) } - - compareIP := cmp.Comparer(func(a, b netip.Addr) bool { return a == b }) - compareIPPrefix := cmp.Comparer(func(a, b netip.Prefix) bool { return a == b }) - - if diff := cmp.Diff(got, tt.want, compareIP, compareIPPrefix); diff != "" { + cmpOpts := []cmp.Option{ + cmp.Comparer(func(a, b netip.Addr) bool { return a == b }), + cmp.Comparer(func(a, b netip.Prefix) bool { return a == b }), + cmpopts.IgnoreFields(Match{}, ".SrcsContains"), + } + if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" { t.Errorf("wrong (-got+want)\n%s", diff) } }) diff --git a/wgengine/filter/match.go b/wgengine/filter/match.go index b0ebfbb41..7bb063d7a 100644 --- a/wgengine/filter/match.go +++ b/wgengine/filter/match.go @@ -66,10 +66,11 @@ type CapMatch struct { // Match matches packets from any IP address in Srcs to any ip:port in // Dsts. type Match struct { - IPProto []ipproto.Proto // required set (no default value at this layer) - Srcs []netip.Prefix - Dsts []NetPortRange // optional, if Srcs match - Caps []CapMatch // optional, if Srcs match + IPProto []ipproto.Proto // required set (no default value at this layer) + Srcs []netip.Prefix + SrcsContains func(netip.Addr) bool `json:"-"` // report whether Addr is in Srcs + Dsts []NetPortRange // optional, if Srcs match + Caps []CapMatch // optional, if Srcs match } func (m Match) String() string { @@ -104,7 +105,7 @@ func (ms matches) match(q *packet.Parsed) bool { if !slices.Contains(m.IPProto, q.IPProto) { continue } - if !ipInList(q.Src.Addr(), m.Srcs) { + if !m.SrcsContains(q.Src.Addr()) { continue } for _, dst := range m.Dsts { @@ -122,7 +123,7 @@ func (ms matches) match(q *packet.Parsed) bool { func (ms matches) matchIPsOnly(q *packet.Parsed) bool { for _, m := range ms { - if !ipInList(q.Src.Addr(), m.Srcs) { + if !m.SrcsContains(q.Src.Addr()) { continue } for _, dst := range m.Dsts { @@ -142,7 +143,7 @@ func (ms matches) matchProtoAndIPsOnlyIfAllPorts(q *packet.Parsed) bool { if !slices.Contains(m.IPProto, q.IPProto) { continue } - if !ipInList(q.Src.Addr(), m.Srcs) { + if !m.SrcsContains(q.Src.Addr()) { continue } for _, dst := range m.Dsts { @@ -156,12 +157,3 @@ func (ms matches) matchProtoAndIPsOnlyIfAllPorts(q *packet.Parsed) bool { } return false } - -func ipInList(ip netip.Addr, netlist []netip.Prefix) bool { - for _, net := range netlist { - if net.Contains(ip) { - return true - } - } - return false -} diff --git a/wgengine/filter/tailcfg.go b/wgengine/filter/tailcfg.go index 9f64587d7..6bdb5b163 100644 --- a/wgengine/filter/tailcfg.go +++ b/wgengine/filter/tailcfg.go @@ -10,8 +10,10 @@ import ( "go4.org/netipx" "tailscale.com/net/netaddr" + "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/ipproto" + "tailscale.com/types/views" ) var defaultProtos = []ipproto.Proto{ @@ -61,6 +63,7 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) { } m.Srcs = append(m.Srcs, nets...) } + m.SrcsContains = tsaddr.NewContainsIPFunc(views.SliceOf(m.Srcs)) for _, d := range r.DstPorts { nets, err := parseIPSet(d.IP, d.Bits)