tsweb: propagate RequestID via context and entire request
The recent addition of RequestID was only populated if the HTTP Request had returned an error. This meant that the underlying handler has no access to this request id and any logs it may have emitted were impossible to correlate to that request id. Therefore, this PR adds a middleware to generate request ids and pass them through the request context. The tsweb.StdHandler automatically populates this request id if the middleware is being used. Finally, inner handlers can use the context to retrieve that same request id and use it so that all logs and events can be correlated. Updates #2549 Signed-off-by: Marwan Sulaiman <marwan@tailscale.com>
This commit is contained in:

committed by
Marwan Sulaiman

parent
c27aa9e7ff
commit
b819f66eb1
@ -67,20 +67,17 @@ func TestStdHandler(t *testing.T) {
|
||||
bgCtx = context.Background()
|
||||
// canceledCtx, cancel = context.WithCancel(bgCtx)
|
||||
startTime = time.Unix(1687870000, 1234)
|
||||
|
||||
setExampleRequestID = func(_ *http.Request) RequestID { return exampleRequestID }
|
||||
)
|
||||
// cancel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rh ReturnHandler
|
||||
r *http.Request
|
||||
errHandler ErrorHandlerFunc
|
||||
generateRequestID func(*http.Request) RequestID
|
||||
wantCode int
|
||||
wantLog AccessLogRecord
|
||||
wantBody string
|
||||
name string
|
||||
rh ReturnHandler
|
||||
r *http.Request
|
||||
errHandler ErrorHandlerFunc
|
||||
wantCode int
|
||||
wantLog AccessLogRecord
|
||||
wantBody string
|
||||
}{
|
||||
{
|
||||
name: "handler returns 200",
|
||||
@ -100,11 +97,10 @@ func TestStdHandler(t *testing.T) {
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns 200 with request ID",
|
||||
rh: handlerCode(200),
|
||||
r: req(bgCtx, "http://example.com/"),
|
||||
generateRequestID: setExampleRequestID,
|
||||
wantCode: 200,
|
||||
name: "handler returns 200 with request ID",
|
||||
rh: handlerCode(200),
|
||||
r: req(bgCtx, "http://example.com/"),
|
||||
wantCode: 200,
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
Seconds: 1.0,
|
||||
@ -134,11 +130,10 @@ func TestStdHandler(t *testing.T) {
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns 404 with request ID",
|
||||
rh: handlerCode(404),
|
||||
r: req(bgCtx, "http://example.com/foo"),
|
||||
generateRequestID: setExampleRequestID,
|
||||
wantCode: 404,
|
||||
name: "handler returns 404 with request ID",
|
||||
rh: handlerCode(404),
|
||||
r: req(bgCtx, "http://example.com/foo"),
|
||||
wantCode: 404,
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
Seconds: 1.0,
|
||||
@ -169,11 +164,10 @@ func TestStdHandler(t *testing.T) {
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns 404 via HTTPError with request ID",
|
||||
rh: handlerErr(0, Error(404, "not found", testErr)),
|
||||
r: req(bgCtx, "http://example.com/foo"),
|
||||
generateRequestID: setExampleRequestID,
|
||||
wantCode: 404,
|
||||
name: "handler returns 404 via HTTPError with request ID",
|
||||
rh: handlerErr(0, Error(404, "not found", testErr)),
|
||||
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
|
||||
wantCode: 404,
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
Seconds: 1.0,
|
||||
@ -207,11 +201,10 @@ func TestStdHandler(t *testing.T) {
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns 404 with request ID and nil child error",
|
||||
rh: handlerErr(0, Error(404, "not found", nil)),
|
||||
r: req(bgCtx, "http://example.com/foo"),
|
||||
generateRequestID: setExampleRequestID,
|
||||
wantCode: 404,
|
||||
name: "handler returns 404 with request ID and nil child error",
|
||||
rh: handlerErr(0, Error(404, "not found", nil)),
|
||||
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
|
||||
wantCode: 404,
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
Seconds: 1.0,
|
||||
@ -245,11 +238,10 @@ func TestStdHandler(t *testing.T) {
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns user-visible error with request ID",
|
||||
rh: handlerErr(0, vizerror.New("visible error")),
|
||||
r: req(bgCtx, "http://example.com/foo"),
|
||||
generateRequestID: setExampleRequestID,
|
||||
wantCode: 500,
|
||||
name: "handler returns user-visible error with request ID",
|
||||
rh: handlerErr(0, vizerror.New("visible error")),
|
||||
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
|
||||
wantCode: 500,
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
Seconds: 1.0,
|
||||
@ -283,11 +275,10 @@ func TestStdHandler(t *testing.T) {
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns user-visible error wrapped by private error with request ID",
|
||||
rh: handlerErr(0, fmt.Errorf("private internal error: %w", vizerror.New("visible error"))),
|
||||
r: req(bgCtx, "http://example.com/foo"),
|
||||
generateRequestID: setExampleRequestID,
|
||||
wantCode: 500,
|
||||
name: "handler returns user-visible error wrapped by private error with request ID",
|
||||
rh: handlerErr(0, fmt.Errorf("private internal error: %w", vizerror.New("visible error"))),
|
||||
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
|
||||
wantCode: 500,
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
Seconds: 1.0,
|
||||
@ -321,11 +312,10 @@ func TestStdHandler(t *testing.T) {
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns generic error with request ID",
|
||||
rh: handlerErr(0, testErr),
|
||||
r: req(bgCtx, "http://example.com/foo"),
|
||||
generateRequestID: setExampleRequestID,
|
||||
wantCode: 500,
|
||||
name: "handler returns generic error with request ID",
|
||||
rh: handlerErr(0, testErr),
|
||||
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
|
||||
wantCode: 500,
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
Seconds: 1.0,
|
||||
@ -358,11 +348,10 @@ func TestStdHandler(t *testing.T) {
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns error after writing response with request ID",
|
||||
rh: handlerErr(200, testErr),
|
||||
r: req(bgCtx, "http://example.com/foo"),
|
||||
generateRequestID: setExampleRequestID,
|
||||
wantCode: 200,
|
||||
name: "handler returns error after writing response with request ID",
|
||||
rh: handlerErr(200, testErr),
|
||||
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
|
||||
wantCode: 200,
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
Seconds: 1.0,
|
||||
@ -455,13 +444,13 @@ func TestStdHandler(t *testing.T) {
|
||||
},
|
||||
|
||||
{
|
||||
name: "error handler gets run with request ID",
|
||||
rh: handlerErr(0, Error(404, "not found", nil)), // status code changed in errHandler
|
||||
r: req(bgCtx, "http://example.com/"),
|
||||
generateRequestID: setExampleRequestID,
|
||||
wantCode: 200,
|
||||
name: "error handler gets run with request ID",
|
||||
rh: handlerErr(0, Error(404, "not found", nil)), // status code changed in errHandler
|
||||
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/"),
|
||||
wantCode: 200,
|
||||
errHandler: func(w http.ResponseWriter, r *http.Request, e HTTPError) {
|
||||
http.Error(w, fmt.Sprintf("%s with request ID %s", e.Msg, e.RequestID), 200)
|
||||
requestID := RequestIDFromContext(r.Context())
|
||||
http.Error(w, fmt.Sprintf("%s with request ID %s", e.Msg, requestID), 200)
|
||||
},
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
@ -477,37 +466,6 @@ func TestStdHandler(t *testing.T) {
|
||||
},
|
||||
wantBody: "not found with request ID " + exampleRequestID + "\n",
|
||||
},
|
||||
|
||||
{
|
||||
name: "request ID can use information from request",
|
||||
rh: handlerErr(0, Error(400, "bad request", nil)),
|
||||
r: func() *http.Request {
|
||||
r := req(bgCtx, "http://example.com/")
|
||||
r.AddCookie(&http.Cookie{Name: "want_request_id", Value: "asdf1234"})
|
||||
return r
|
||||
}(),
|
||||
generateRequestID: func(r *http.Request) RequestID {
|
||||
c, _ := r.Cookie("want_request_id")
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
return RequestID(c.Value)
|
||||
},
|
||||
wantCode: 400,
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
Seconds: 1.0,
|
||||
Proto: "HTTP/1.1",
|
||||
TLS: false,
|
||||
Host: "example.com",
|
||||
RequestURI: "/",
|
||||
Method: "GET",
|
||||
Code: 400,
|
||||
Err: "bad request",
|
||||
RequestID: "asdf1234",
|
||||
},
|
||||
wantBody: "bad request\nasdf1234\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
@ -526,7 +484,7 @@ func TestStdHandler(t *testing.T) {
|
||||
})
|
||||
|
||||
rec := noopHijacker{httptest.NewRecorder(), false}
|
||||
h := StdHandler(test.rh, HandlerOptions{Logf: logf, Now: clock.Now, GenerateRequestID: test.generateRequestID, OnError: test.errHandler})
|
||||
h := StdHandler(test.rh, HandlerOptions{Logf: logf, Now: clock.Now, OnError: test.errHandler})
|
||||
h.ServeHTTP(&rec, test.r)
|
||||
res := rec.Result()
|
||||
if res.StatusCode != test.wantCode {
|
||||
|
Reference in New Issue
Block a user