diff --git a/client/web/web.go b/client/web/web.go index 24fd71c8b..e4c5c68b1 100644 --- a/client/web/web.go +++ b/client/web/web.go @@ -965,6 +965,13 @@ func (s *Server) proxyRequestToLocalAPI(w http.ResponseWriter, r *http.Request) http.Error(w, "invalid request", http.StatusBadRequest) return } + if r.Method == httpm.PATCH { + // enforce that PATCH requests are always application/json + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + http.Error(w, "invalid request", http.StatusBadRequest) + return + } + } if !slices.Contains(localapiAllowlist, path) { http.Error(w, fmt.Sprintf("%s not allowed from localapi proxy", path), http.StatusForbidden) return diff --git a/client/web/web_test.go b/client/web/web_test.go index d37e3e00a..dc5670fbb 100644 --- a/client/web/web_test.go +++ b/client/web/web_test.go @@ -100,29 +100,44 @@ func TestServeAPI(t *testing.T) { s := &Server{lc: &tailscale.LocalClient{Dial: lal.Dial}} tests := []struct { - name string - reqPath string - wantResp string - wantStatus int + name string + reqMethod string + reqPath string + reqContentType string + wantResp string + wantStatus int }{{ name: "invalid_endpoint", + reqMethod: httpm.POST, reqPath: "/not-an-endpoint", wantResp: "invalid endpoint", wantStatus: http.StatusNotFound, }, { name: "not_in_localapi_allowlist", + reqMethod: httpm.POST, reqPath: "/local/v0/not-allowlisted", wantResp: "/v0/not-allowlisted not allowed from localapi proxy", wantStatus: http.StatusForbidden, }, { name: "in_localapi_allowlist", + reqMethod: httpm.POST, reqPath: "/local/v0/logout", wantResp: "success", // Successfully allowed to hit localapi. wantStatus: http.StatusOK, + }, { + name: "patch_bad_contenttype", + reqMethod: httpm.PATCH, + reqPath: "/local/v0/prefs", + reqContentType: "multipart/form-data", + wantResp: "invalid request", + wantStatus: http.StatusBadRequest, }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - r := httptest.NewRequest("POST", "/api"+tt.reqPath, nil) + r := httptest.NewRequest(tt.reqMethod, "/api"+tt.reqPath, nil) + if tt.reqContentType != "" { + r.Header.Add("Content-Type", tt.reqContentType) + } w := httptest.NewRecorder() s.serveAPI(w, r)