From eced054796d194845857e2df7764e6e07fa8ee1e Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Fri, 20 Oct 2023 12:04:00 +0100 Subject: [PATCH] ipn/ipnlocal: close connections for removed proxy transports (#9884) Ensure that when a userspace proxy config is reloaded, connections for any removed proxies are safely closed Updates tailscale/tailscale#9725 Signed-off-by: Irbe Krumina --- ipn/ipnlocal/local.go | 5 +- ipn/ipnlocal/serve.go | 63 ++++++++++----- ipn/ipnlocal/serve_test.go | 157 +++++++++++++++++++++++++++++-------- 3 files changed, 170 insertions(+), 55 deletions(-) diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 96bb69a6f..c01eff80e 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -14,7 +14,6 @@ import ( "maps" "net" "net/http" - "net/http/httputil" "net/netip" "net/url" "os" @@ -268,7 +267,7 @@ type LocalBackend struct { activeWatchSessions set.Set[string] // of WatchIPN SessionID serveListeners map[netip.AddrPort]*serveListener // addrPort => serveListener - serveProxyHandlers sync.Map // string (HTTPHandler.Proxy) => *httputil.ReverseProxy + serveProxyHandlers sync.Map // string (HTTPHandler.Proxy) => *reverseProxy // statusLock must be held before calling statusChanged.Wait() or // statusChanged.Broadcast(). @@ -4432,8 +4431,8 @@ func (b *LocalBackend) setServeProxyHandlersLocked() { backend := key.(string) if !backends[backend] { b.logf("serve: closing idle connections to %s", backend) - value.(*httputil.ReverseProxy).Transport.(*http.Transport).CloseIdleConnections() b.serveProxyHandlers.Delete(backend) + value.(*reverseProxy).close() } return true }) diff --git a/ipn/ipnlocal/serve.go b/ipn/ipnlocal/serve.go index d3fc25d80..9b218a71f 100644 --- a/ipn/ipnlocal/serve.go +++ b/ipn/ipnlocal/serve.go @@ -23,6 +23,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "golang.org/x/net/http2" @@ -564,31 +565,52 @@ func (b *LocalBackend) proxyHandlerForBackend(backend string) (http.Handler, err // has application/grpc content type header, the connection will be over h2c. // Otherwise standard Go http transport will be used. type reverseProxy struct { - logf logger.Logf - url *url.URL - insecure bool - backend string - lb *LocalBackend - // transport for non-h2c backends - httpTransport lazy.SyncValue[http.RoundTripper] - // transport for h2c backends - h2cTransport lazy.SyncValue[http.RoundTripper] + logf logger.Logf + url *url.URL + // insecure tracks whether the connection to an https backend should be + // insecure (i.e because we cannot verify its CA). + insecure bool + backend string + lb *LocalBackend + httpTransport lazy.SyncValue[*http.Transport] // transport for non-h2c backends + h2cTransport lazy.SyncValue[*http2.Transport] // transport for h2c backends + // closed tracks whether proxy is closed/currently closing. + closed atomic.Bool +} + +// close ensures that any open backend connections get closed. +func (rp *reverseProxy) close() { + rp.closed.Store(true) + if h2cT := rp.h2cTransport.Get(func() *http2.Transport { + return nil + }); h2cT != nil { + h2cT.CloseIdleConnections() + } + if httpTransport := rp.httpTransport.Get(func() *http.Transport { + return nil + }); httpTransport != nil { + httpTransport.CloseIdleConnections() + } } func (rp *reverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if closed := rp.closed.Load(); closed { + rp.logf("received a request for a proxy that's being closed or has been closed") + http.Error(w, "proxy is closed", http.StatusServiceUnavailable) + return + } p := &httputil.ReverseProxy{Rewrite: func(r *httputil.ProxyRequest) { r.SetURL(rp.url) r.Out.Host = r.In.Host addProxyForwardedHeaders(r) rp.lb.addTailscaleIdentityHeaders(r) - }, - } + }} // There is no way to autodetect h2c as per RFC 9113 // https://datatracker.ietf.org/doc/html/rfc9113#name-starting-http-2. // However, we assume that http:// proxy prefix in combination with the // protoccol being HTTP/2 is sufficient to detect h2c for our needs. Only use this for - // gRPC to fix a known problem pf plaintext gRPC backends + // gRPC to fix a known problem of plaintext gRPC backends if rp.shouldProxyViaH2C(r) { rp.logf("received a proxy request for plaintext gRPC") p.Transport = rp.getH2CTransport() @@ -596,13 +618,12 @@ func (rp *reverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { p.Transport = rp.getTransport() } p.ServeHTTP(w, r) - } -// getTransport gets transport for http backends. Transport gets created lazily -// at most once. -func (rp *reverseProxy) getTransport() http.RoundTripper { - return rp.httpTransport.Get(func() http.RoundTripper { +// getTransport returns the Transport used for regular (non-GRPC) requests +// to the backend. The Transport gets created lazily, at most once. +func (rp *reverseProxy) getTransport() *http.Transport { + return rp.httpTransport.Get(func() *http.Transport { return &http.Transport{ DialContext: rp.lb.dialer.SystemDial, TLSClientConfig: &tls.Config{ @@ -618,10 +639,10 @@ func (rp *reverseProxy) getTransport() http.RoundTripper { }) } -// getH2CTranport gets transport for h2c backends. Creates it lazily at most -// once. -func (rp *reverseProxy) getH2CTransport() http.RoundTripper { - return rp.h2cTransport.Get(func() http.RoundTripper { +// getH2CTransport returns the Transport used for GRPC requests to the backend. +// The Transport gets created lazily, at most once. +func (rp *reverseProxy) getH2CTransport() *http2.Transport { + return rp.h2cTransport.Get(func() *http2.Transport { return &http2.Transport{ AllowHTTP: true, DialTLSContext: func(ctx context.Context, network string, addr string, _ *tls.Config) (net.Conn, error) { diff --git a/ipn/ipnlocal/serve_test.go b/ipn/ipnlocal/serve_test.go index a561c210b..a2918eb78 100644 --- a/ipn/ipnlocal/serve_test.go +++ b/ipn/ipnlocal/serve_test.go @@ -18,6 +18,7 @@ import ( "net/url" "os" "path/filepath" + "reflect" "strings" "testing" "time" @@ -446,6 +447,119 @@ func TestServeHTTPProxy(t *testing.T) { } } +func Test_reverseProxyConfiguration(t *testing.T) { + b := newTestBackend(t) + type test struct { + backend string + path string + // set to false to test that a proxy has been removed + shouldExist bool + wantsInsecure bool + wantsURL url.URL + } + runner := func(name string, tests []test) { + t.Logf("running tests for %s", name) + host := ipn.HostPort("http://example.ts.net:80") + conf := &ipn.ServeConfig{ + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + host: {Handlers: map[string]*ipn.HTTPHandler{}}, + }, + } + for _, tt := range tests { + if tt.shouldExist { + conf.Web[host].Handlers[tt.path] = &ipn.HTTPHandler{Proxy: tt.backend} + } + } + if err := b.setServeConfigLocked(conf, ""); err != nil { + t.Fatal(err) + } + // test that reverseproxies have been set up as expected + for _, tt := range tests { + rp, ok := b.serveProxyHandlers.Load(tt.backend) + if !tt.shouldExist && ok { + t.Errorf("proxy for backend %s should not exist, but it does", tt.backend) + } + if !tt.shouldExist { + continue + } + parsedRp, ok := rp.(*reverseProxy) + if !ok { + t.Errorf("proxy for backend %q is not a reverseproxy", tt.backend) + } + if parsedRp.insecure != tt.wantsInsecure { + t.Errorf("proxy for backend %q should be insecure: %v got insecure: %v", tt.backend, tt.wantsInsecure, parsedRp.insecure) + } + if !reflect.DeepEqual(*parsedRp.url, tt.wantsURL) { + t.Errorf("proxy for backend %q should have URL %#+v, got URL %+#v", tt.backend, &tt.wantsURL, parsedRp.url) + } + if tt.backend != parsedRp.backend { + t.Errorf("proxy for backend %q should have backend %q got %q", tt.backend, tt.backend, parsedRp.backend) + } + } + } + + // configure local backend with some proxy backends + runner("initial proxy configs", []test{ + { + backend: "http://example.com/docs", + path: "/example", + shouldExist: true, + wantsInsecure: false, + wantsURL: mustCreateURL(t, "http://example.com/docs"), + }, + { + backend: "https://example1.com", + path: "/example1", + shouldExist: true, + wantsInsecure: false, + wantsURL: mustCreateURL(t, "https://example1.com"), + }, + { + backend: "https+insecure://example2.com", + path: "/example2", + shouldExist: true, + wantsInsecure: true, + wantsURL: mustCreateURL(t, "https://example2.com"), + }, + }) + + // reconfigure the local backend with different proxies + runner("reloaded proxy configs", []test{ + { + backend: "http://example.com/docs", + path: "/example", + shouldExist: true, + wantsInsecure: false, + wantsURL: mustCreateURL(t, "http://example.com/docs"), + }, + { + backend: "https://example1.com", + shouldExist: false, + }, + { + backend: "https+insecure://example2.com", + shouldExist: false, + }, + { + backend: "https+insecure://example3.com", + path: "/example3", + shouldExist: true, + wantsInsecure: true, + wantsURL: mustCreateURL(t, "https://example3.com"), + }, + }) + +} + +func mustCreateURL(t *testing.T, u string) url.URL { + t.Helper() + uParsed, err := url.Parse(u) + if err != nil { + t.Fatalf("failed parsing url: %v", err) + } + return *uParsed +} + func newTestBackend(t *testing.T) *LocalBackend { sys := &tsd.System{} e, err := wgengine.NewUserspaceEngine(t.Logf, wgengine.Config{SetSubsystem: sys.Set}) @@ -589,40 +703,21 @@ func TestServeFileOrDirectory(t *testing.T) { } func Test_isGRPCContentType(t *testing.T) { - tests := map[string]struct { + tests := []struct { contentType string want bool }{ - "application/grpc": { - contentType: "application/grpc", - want: true, - }, - "application/grpc;": { - contentType: "application/grpc;", - want: true, - }, - "application/grpc+": { - contentType: "application/grpc+", - want: true, - }, - "application/grpcfoobar": { - contentType: "application/grpcfoobar", - }, - "application/text": { - contentType: "application/text", - }, - "foobar": { - contentType: "foobar", - }, - "no content type": { - contentType: "", - }, + {contentType: "application/grpc", want: true}, + {contentType: "application/grpc;", want: true}, + {contentType: "application/grpc+", want: true}, + {contentType: "application/grpcfoobar"}, + {contentType: "application/text"}, + {contentType: "foobar"}, + {contentType: ""}, } - for name, scenario := range tests { - t.Run(name, func(t *testing.T) { - if got := isGRPCContentType(scenario.contentType); got != scenario.want { - t.Errorf("test case %s failed, isGRPCContentType() = %v, want %v", name, got, scenario.want) - } - }) + for _, tt := range tests { + if got := isGRPCContentType(tt.contentType); got != tt.want { + t.Errorf("isGRPCContentType(%q) = %v, want %v", tt.contentType, got, tt.want) + } } }