wgengine/netstack: add a per-client limit for in-flight TCP forwards
This is a fun one. Right now, when a client is connecting through a subnet router, here's roughly what happens: 1. The client initiates a connection to an IP address behind a subnet router, and sends a TCP SYN 2. The subnet router gets the SYN packet from netstack, and after running through acceptTCP, starts DialContext-ing the destination IP, without accepting the connection¹ 3. The client retransmits the SYN packet a few times while the dial is in progress, until either... 4. The subnet router successfully establishes a connection to the destination IP and sends the SYN-ACK back to the client, or... 5. The subnet router times out and sends a RST to the client. 6. If the connection was successful, the client ACKs the SYN-ACK it received, and traffic starts flowing As a result, the notification code in forwardTCP never notices when a new connection attempt is aborted, and it will wait until either the connection is established, or until the OS-level connection timeout is reached and it aborts. To mitigate this, add a per-client limit on how many in-flight TCP forwarding connections can be in-progress; after this, clients will see a similar behaviour to the global limit, where new connection attempts are aborted instead of waiting. This prevents a single misbehaving client from blocking all other clients of a subnet router by ensuring that it doesn't starve the global limiter. Also, bump the global limit again to a higher value. ¹ We can't accept the connection before establishing a connection to the remote server since otherwise we'd be opening the connection and then immediately closing it, which breaks a bunch of stuff; see #5503 for more details. Updates tailscale/corp#12184 Signed-off-by: Andrew Dunham <andrew@du.nham.ca> Change-Id: I76e7008ddd497303d75d473f534e32309c8a5144
This commit is contained in:
@ -4,14 +4,22 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/ipn"
|
||||
"tailscale.com/ipn/ipnlocal"
|
||||
"tailscale.com/ipn/store/mem"
|
||||
"tailscale.com/metrics"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/net/tsdial"
|
||||
@ -455,3 +463,234 @@ func TestShouldProcessInbound(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func tcp4syn(tb testing.TB, src, dst netip.Addr, sport, dport uint16) []byte {
|
||||
ip := header.IPv4(make([]byte, header.IPv4MinimumSize+header.TCPMinimumSize))
|
||||
ip.Encode(&header.IPv4Fields{
|
||||
Protocol: uint8(header.TCPProtocolNumber),
|
||||
TotalLength: header.IPv4MinimumSize + header.TCPMinimumSize,
|
||||
TTL: 64,
|
||||
SrcAddr: tcpip.AddrFrom4Slice(src.AsSlice()),
|
||||
DstAddr: tcpip.AddrFrom4Slice(dst.AsSlice()),
|
||||
})
|
||||
ip.SetChecksum(^ip.CalculateChecksum())
|
||||
if !ip.IsChecksumValid() {
|
||||
tb.Fatal("test broken; packet has incorrect IP checksum")
|
||||
}
|
||||
|
||||
tcp := header.TCP(ip[header.IPv4MinimumSize:])
|
||||
tcp.Encode(&header.TCPFields{
|
||||
SrcPort: sport,
|
||||
DstPort: dport,
|
||||
SeqNum: 0,
|
||||
DataOffset: header.TCPMinimumSize,
|
||||
Flags: header.TCPFlagSyn,
|
||||
WindowSize: 65535,
|
||||
Checksum: 0,
|
||||
})
|
||||
xsum := header.PseudoHeaderChecksum(
|
||||
header.TCPProtocolNumber,
|
||||
tcpip.AddrFrom4Slice(src.AsSlice()),
|
||||
tcpip.AddrFrom4Slice(dst.AsSlice()),
|
||||
uint16(header.TCPMinimumSize),
|
||||
)
|
||||
tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
|
||||
if !tcp.IsChecksumValid(tcpip.AddrFrom4Slice(src.AsSlice()), tcpip.AddrFrom4Slice(dst.AsSlice()), 0, 0) {
|
||||
tb.Fatal("test broken; packet has incorrect TCP checksum")
|
||||
}
|
||||
|
||||
return ip
|
||||
}
|
||||
|
||||
// makeHangDialer returns a dialer that notifies the returned channel when a
|
||||
// connection is dialed and then hangs until the test finishes.
|
||||
func makeHangDialer(tb testing.TB) (func(context.Context, string, string) (net.Conn, error), chan struct{}) {
|
||||
done := make(chan struct{})
|
||||
tb.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
|
||||
gotConn := make(chan struct{}, 1)
|
||||
fn := func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
// Signal that we have a new connection
|
||||
tb.Logf("hangDialer: called with network=%q address=%q", network, address)
|
||||
select {
|
||||
case gotConn <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
// Hang until the test is done.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
tb.Logf("context done")
|
||||
case <-done:
|
||||
tb.Logf("function completed")
|
||||
}
|
||||
return nil, fmt.Errorf("canceled")
|
||||
}
|
||||
return fn, gotConn
|
||||
}
|
||||
|
||||
// TestTCPForwardLimits verifies that the limits on the TCP forwarder work in a
|
||||
// success case (i.e. when we don't hit the limit).
|
||||
func TestTCPForwardLimits(t *testing.T) {
|
||||
envknob.Setenv("TS_DEBUG_NETSTACK", "true")
|
||||
impl := makeNetstack(t, func(impl *Impl) {
|
||||
impl.ProcessSubnets = true
|
||||
})
|
||||
|
||||
dialFn, gotConn := makeHangDialer(t)
|
||||
impl.forwardDialFunc = dialFn
|
||||
|
||||
prefs := ipn.NewPrefs()
|
||||
prefs.AdvertiseRoutes = []netip.Prefix{
|
||||
// This is the TEST-NET-1 IP block for use in documentation,
|
||||
// and should never actually be routable.
|
||||
netip.MustParsePrefix("192.0.2.0/24"),
|
||||
}
|
||||
impl.lb.Start(ipn.Options{
|
||||
LegacyMigrationPrefs: prefs,
|
||||
})
|
||||
impl.atomicIsLocalIPFunc.Store(looksLikeATailscaleSelfAddress)
|
||||
|
||||
// Inject an "outbound" packet that's going to an IP address that times
|
||||
// out. We need to re-parse from a byte slice so that the internal
|
||||
// buffer in the packet.Parsed type is filled out.
|
||||
client := netip.MustParseAddr("100.101.102.103")
|
||||
destAddr := netip.MustParseAddr("192.0.2.1")
|
||||
pkt := tcp4syn(t, client, destAddr, 1234, 4567)
|
||||
var parsed packet.Parsed
|
||||
parsed.Decode(pkt)
|
||||
|
||||
// When injecting this packet, we want the outcome to be "drop
|
||||
// silently", which indicates that netstack is processing the
|
||||
// packet and not delivering it to the host system.
|
||||
if resp := impl.injectInbound(&parsed, impl.tundev); resp != filter.DropSilently {
|
||||
t.Errorf("got filter outcome %v, want filter.DropSilently", resp)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Wait until we have an in-flight outgoing connection.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatalf("timed out waiting for connection")
|
||||
case <-gotConn:
|
||||
t.Logf("got connection in progress")
|
||||
}
|
||||
|
||||
// Verify that we now have a single in-flight address in our map.
|
||||
impl.mu.Lock()
|
||||
inFlight := maps.Clone(impl.connsInFlightByClient)
|
||||
impl.mu.Unlock()
|
||||
|
||||
if got, ok := inFlight[client]; !ok || got != 1 {
|
||||
t.Errorf("expected 1 in-flight connection for %v, got: %v", client, inFlight)
|
||||
}
|
||||
|
||||
// Get the expvar statistics and verify that we're exporting the
|
||||
// correct metric.
|
||||
metrics := impl.ExpVar().(*metrics.Set)
|
||||
|
||||
const metricName = "gauge_tcp_forward_in_flight"
|
||||
if v := metrics.Get(metricName).String(); v != "1" {
|
||||
t.Errorf("got metric %q=%s, want 1", metricName, v)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTCPForwardLimits_PerClient verifies that the per-client limit for TCP
|
||||
// forwarding works.
|
||||
func TestTCPForwardLimits_PerClient(t *testing.T) {
|
||||
envknob.Setenv("TS_DEBUG_NETSTACK", "true")
|
||||
|
||||
// Set our test override limits during this test.
|
||||
tstest.Replace(t, &maxInFlightConnectionAttemptsForTest, 2)
|
||||
tstest.Replace(t, &maxInFlightConnectionAttemptsPerClientForTest, 1)
|
||||
|
||||
impl := makeNetstack(t, func(impl *Impl) {
|
||||
impl.ProcessSubnets = true
|
||||
})
|
||||
|
||||
dialFn, gotConn := makeHangDialer(t)
|
||||
impl.forwardDialFunc = dialFn
|
||||
|
||||
prefs := ipn.NewPrefs()
|
||||
prefs.AdvertiseRoutes = []netip.Prefix{
|
||||
// This is the TEST-NET-1 IP block for use in documentation,
|
||||
// and should never actually be routable.
|
||||
netip.MustParsePrefix("192.0.2.0/24"),
|
||||
}
|
||||
impl.lb.Start(ipn.Options{
|
||||
LegacyMigrationPrefs: prefs,
|
||||
})
|
||||
impl.atomicIsLocalIPFunc.Store(looksLikeATailscaleSelfAddress)
|
||||
|
||||
// Inject an "outbound" packet that's going to an IP address that times
|
||||
// out. We need to re-parse from a byte slice so that the internal
|
||||
// buffer in the packet.Parsed type is filled out.
|
||||
client := netip.MustParseAddr("100.101.102.103")
|
||||
destAddr := netip.MustParseAddr("192.0.2.1")
|
||||
|
||||
// Helpers
|
||||
mustInjectPacket := func() {
|
||||
pkt := tcp4syn(t, client, destAddr, 1234, 4567)
|
||||
var parsed packet.Parsed
|
||||
parsed.Decode(pkt)
|
||||
|
||||
// When injecting this packet, we want the outcome to be "drop
|
||||
// silently", which indicates that netstack is processing the
|
||||
// packet and not delivering it to the host system.
|
||||
if resp := impl.injectInbound(&parsed, impl.tundev); resp != filter.DropSilently {
|
||||
t.Fatalf("got filter outcome %v, want filter.DropSilently", resp)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
waitPacket := func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatalf("timed out waiting for connection")
|
||||
case <-gotConn:
|
||||
t.Logf("got connection in progress")
|
||||
}
|
||||
}
|
||||
|
||||
// Inject the packet to start the TCP forward and wait until we have an
|
||||
// in-flight outgoing connection.
|
||||
mustInjectPacket()
|
||||
waitPacket()
|
||||
|
||||
// Verify that we now have a single in-flight address in our map.
|
||||
impl.mu.Lock()
|
||||
inFlight := maps.Clone(impl.connsInFlightByClient)
|
||||
impl.mu.Unlock()
|
||||
|
||||
if got, ok := inFlight[client]; !ok || got != 1 {
|
||||
t.Errorf("expected 1 in-flight connection for %v, got: %v", client, inFlight)
|
||||
}
|
||||
|
||||
metrics := impl.ExpVar().(*metrics.Set)
|
||||
|
||||
// One client should have reached the limit at this point.
|
||||
if v := metrics.Get("gauge_tcp_forward_in_flight_per_client_limit_reached").String(); v != "1" {
|
||||
t.Errorf("got limit reached expvar metric=%s, want 1", v)
|
||||
}
|
||||
|
||||
// Inject another packet, and verify that we've incremented our
|
||||
// "dropped" metrics since this will have been dropped.
|
||||
mustInjectPacket()
|
||||
|
||||
// expvar metric
|
||||
const metricName = "counter_tcp_forward_max_in_flight_per_client_drop"
|
||||
if v := metrics.Get(metricName).String(); v != "1" {
|
||||
t.Errorf("got expvar metric %q=%s, want 1", metricName, v)
|
||||
}
|
||||
|
||||
// client metric
|
||||
if v := metricPerClientForwardLimit.Value(); v != 1 {
|
||||
t.Errorf("got clientmetric limit metric=%d, want 1", v)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user