diff --git a/tsweb/tsweb.go b/tsweb/tsweb.go index 0221dacb9..430ccd5b8 100644 --- a/tsweb/tsweb.go +++ b/tsweb/tsweb.go @@ -472,15 +472,9 @@ func (h logHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Let errorHandler tell us what error it wrote to the client. r = r.WithContext(errCallback.WithValue(ctx, func(e string) { - if ctx.Err() == context.Canceled { - msg.Code = 499 // nginx convention: Client Closed Request - msg.Err = context.Canceled.Error() - return + if msg.Err == "" { + msg.Err = e // Keep the first error. } - if msg.Err != "" { - return - } - msg.Err = e })) lw := newLogResponseWriter(h.opts.Logf, w, r) @@ -513,9 +507,14 @@ func (h logHandler) logRequest(r *http.Request, lw *loggingResponseWriter, msg A // switched protocols away from HTTP. msg.Code = http.StatusSwitchingProtocols case lw.code == 0: - // If the handler didn't write and didn't send a header, that still means 200. - // (See https://play.golang.org/p/4P7nx_Tap7p) - msg.Code = 200 + if r.Context().Err() != nil { + // We didn't write a response before the client disconnected. + msg.Code = 499 + } else { + // If the handler didn't write and didn't send a header, that still means 200. + // (See https://play.golang.org/p/4P7nx_Tap7p) + msg.Code = 200 + } default: msg.Code = lw.code } @@ -669,7 +668,9 @@ func (h errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { err = h.rh.ServeHTTPReturn(lw, r) } -func (h errorHandler) handleError(w http.ResponseWriter, r *http.Request, lw *loggingResponseWriter, err error) (logged bool) { +func (h errorHandler) handleError(w http.ResponseWriter, r *http.Request, lw *loggingResponseWriter, err error) bool { + var logged bool + // Extract a presentable, loggable error. var hOK bool var hErr HTTPError @@ -681,6 +682,8 @@ func (h errorHandler) handleError(w http.ResponseWriter, r *http.Request, lw *lo } } else if v, ok := vizerror.As(err); ok { hErr = Error(http.StatusInternalServerError, v.Error(), nil) + } else if errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) { + hErr = Error(499, "", err) // Nginx convention } else { // Omit the friendly message so HTTP logs show the bare error that was // returned and we know it's not a HTTPError. @@ -699,21 +702,31 @@ func (h errorHandler) handleError(w http.ResponseWriter, r *http.Request, lw *lo logged = true } + if r.Context().Err() != nil { + return logged + } + if lw.code != 0 { - if hOK { - lw.logf("[unexpected] handler returned HTTPError %v, but already sent a response with code %d", hErr, lw.code) + if hOK && hErr.Code != lw.code { + lw.logf("[unexpected] handler returned HTTPError %v, but already sent response with code %d", hErr, lw.code) } - return + return logged } // Set a default error message from the status code. Do this after we pass // the error back to the logger so that `return errors.New("oh")` logs as // `"err": "oh"`, not `"err": "Internal Server Error: oh"`. if hErr.Msg == "" { - hErr.Msg = http.StatusText(hErr.Code) + switch hErr.Code { + case 499: + hErr.Msg = "Client Closed Request" + default: + hErr.Msg = http.StatusText(hErr.Code) + } } // If OnError panics before a response is written, write a bare 500 back. + // OnError panics are thrown further up the stack. defer func() { if lw.code == 0 { if rec := recover(); rec != nil { diff --git a/tsweb/tsweb_test.go b/tsweb/tsweb_test.go index 18bb7e48d..2633534aa 100644 --- a/tsweb/tsweb_test.go +++ b/tsweb/tsweb_test.go @@ -22,6 +22,7 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "tailscale.com/metrics" "tailscale.com/tstest" + "tailscale.com/util/httpm" "tailscale.com/util/must" "tailscale.com/util/vizerror" ) @@ -693,6 +694,81 @@ func TestStdHandler_Panic(t *testing.T) { res.Body.Close() } +func TestStdHandler_Canceled(t *testing.T) { + now := time.Now() + + r := make(chan AccessLogRecord) + var e *HTTPError + handlerOpen := make(chan struct{}) + h := StdHandler( + ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + close(handlerOpen) + ctx := r.Context() + <-ctx.Done() + return ctx.Err() + }), + HandlerOptions{ + Logf: t.Logf, + Now: func() time.Time { return now }, + OnError: func(w http.ResponseWriter, r *http.Request, h HTTPError) { + e = &h + }, + OnCompletion: func(_ *http.Request, alr AccessLogRecord) { + r <- alr + }, + }, + ) + + // Create a context which gets canceled after the handler starts processing + // the request. + ctx, cancelReq := context.WithCancel(context.Background()) + go func() { + <-handlerOpen + cancelReq() + }() + + s := httptest.NewServer(h) + t.Cleanup(s.Close) + + // Send a request to our server. + req, err := http.NewRequestWithContext(ctx, httpm.GET, s.URL, nil) + if err != nil { + t.Fatalf("making request: %s", err) + } + res, err := http.DefaultClient.Do(req) + if !errors.Is(err, context.Canceled) { + t.Errorf("got error %v, want context.Canceled", err) + } + if res != nil { + t.Errorf("got response %#v, want nil", res) + } + + // Check that we got the expected log record. + got := <-r + got.Seconds = 0 + got.RemoteAddr = "" + got.Host = "" + got.UserAgent = "" + want := AccessLogRecord{ + Time: now, + Code: 499, + Method: "GET", + Err: "context canceled", + Proto: "HTTP/1.1", + RequestURI: "/", + } + if d := cmp.Diff(want, got); d != "" { + t.Errorf("AccessLogRecord wrong (-want +got)\n%s", d) + } + + // Check that we rendered no response to the client after + // logHandler.OnCompletion has been called. + if e != nil { + t.Errorf("got OnError callback with %#v, want no callback", e) + } + +} + func TestStdHandler_OnErrorPanic(t *testing.T) { var r AccessLogRecord h := StdHandler(