tsweb: rewrite JSONHandler without using reflect (#684)
Closes #656 #657 Signed-off-by: Zijie Lu <zijie@tailscale.com>
This commit is contained in:
@ -5,9 +5,8 @@
|
||||
package tsweb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@ -26,7 +25,7 @@ type Response struct {
|
||||
}
|
||||
|
||||
func TestNewJSONHandler(t *testing.T) {
|
||||
checkStatus := func(w *httptest.ResponseRecorder, status string) *Response {
|
||||
checkStatus := func(w *httptest.ResponseRecorder, status string, code int) *Response {
|
||||
d := &Response{
|
||||
Data: &Data{},
|
||||
}
|
||||
@ -44,6 +43,10 @@ func TestNewJSONHandler(t *testing.T) {
|
||||
t.Fatalf("wrong status: %s %s", d.Status, status)
|
||||
}
|
||||
|
||||
if w.Code != code {
|
||||
t.Fatalf("wrong status code: %d %d", w.Code, code)
|
||||
}
|
||||
|
||||
if w.Header().Get("Content-Type") != "application/json" {
|
||||
t.Fatalf("wrong content type: %s", w.Header().Get("Content-Type"))
|
||||
}
|
||||
@ -51,163 +54,139 @@ func TestNewJSONHandler(t *testing.T) {
|
||||
return d
|
||||
}
|
||||
|
||||
// 2 1
|
||||
h21 := JSONHandler(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return nil
|
||||
h21 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||
return http.StatusOK, nil, nil
|
||||
})
|
||||
|
||||
t.Run("2 1 simple", func(t *testing.T) {
|
||||
t.Run("200 simple", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
h21.ServeHTTP(w, r)
|
||||
checkStatus(w, "success")
|
||||
checkStatus(w, "success", http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("2 1 HTTPError", func(t *testing.T) {
|
||||
h := JSONHandler(func(w http.ResponseWriter, r *http.Request) HTTPError {
|
||||
return Error(http.StatusForbidden, "forbidden", nil)
|
||||
t.Run("403 HTTPError", func(t *testing.T) {
|
||||
h := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||
return http.StatusForbidden, nil, fmt.Errorf("forbidden")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
h.ServeHTTP(w, r)
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("wrong code: %d %d", w.Code, http.StatusForbidden)
|
||||
}
|
||||
checkStatus(w, "error", http.StatusForbidden)
|
||||
})
|
||||
|
||||
// 2 2
|
||||
h22 := JSONHandler(func(w http.ResponseWriter, r *http.Request) (*Data, error) {
|
||||
return &Data{Name: "tailscale"}, nil
|
||||
h22 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||
return http.StatusOK, &Data{Name: "tailscale"}, nil
|
||||
})
|
||||
t.Run("2 2 get data", func(t *testing.T) {
|
||||
|
||||
t.Run("200 get data", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
h22.ServeHTTP(w, r)
|
||||
checkStatus(w, "success")
|
||||
checkStatus(w, "success", http.StatusOK)
|
||||
})
|
||||
|
||||
// 3 1
|
||||
h31 := JSONHandler(func(w http.ResponseWriter, r *http.Request, d *Data) error {
|
||||
if d.Name == "" {
|
||||
return errors.New("name is empty")
|
||||
h31 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||
body := new(Data)
|
||||
if err := json.NewDecoder(r.Body).Decode(body); err != nil {
|
||||
return http.StatusBadRequest, nil, err
|
||||
}
|
||||
|
||||
return nil
|
||||
if body.Name == "" {
|
||||
return http.StatusBadRequest, nil, Error(http.StatusBadGateway, "name is empty", nil)
|
||||
}
|
||||
|
||||
return http.StatusOK, nil, nil
|
||||
})
|
||||
t.Run("3 1 post data", func(t *testing.T) {
|
||||
t.Run("200 post data", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "tailscale"}`))
|
||||
h31.ServeHTTP(w, r)
|
||||
checkStatus(w, "success")
|
||||
checkStatus(w, "success", http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("3 1 bad json", func(t *testing.T) {
|
||||
t.Run("400 bad json", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{`))
|
||||
h31.ServeHTTP(w, r)
|
||||
checkStatus(w, "error")
|
||||
checkStatus(w, "error", http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("3 1 post data error", func(t *testing.T) {
|
||||
t.Run("400 post data error", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
|
||||
h31.ServeHTTP(w, r)
|
||||
resp := checkStatus(w, "error")
|
||||
resp := checkStatus(w, "error", http.StatusBadRequest)
|
||||
if resp.Error != "name is empty" {
|
||||
t.Fatalf("wrong error")
|
||||
}
|
||||
})
|
||||
|
||||
// 3 2
|
||||
h32 := JSONHandler(func(w http.ResponseWriter, r *http.Request, d *Data) (*Data, error) {
|
||||
if d.Price == 0 {
|
||||
return nil, errors.New("price is empty")
|
||||
h32 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||
body := new(Data)
|
||||
if err := json.NewDecoder(r.Body).Decode(body); err != nil {
|
||||
return http.StatusBadRequest, nil, err
|
||||
}
|
||||
if body.Name == "root" {
|
||||
return http.StatusInternalServerError, nil, fmt.Errorf("invalid name")
|
||||
}
|
||||
if body.Price == 0 {
|
||||
return http.StatusBadRequest, nil, Error(http.StatusBadGateway, "price is empty", nil)
|
||||
}
|
||||
|
||||
return &Data{Price: d.Price * 2}, nil
|
||||
return http.StatusOK, &Data{Price: body.Price * 2}, nil
|
||||
})
|
||||
t.Run("3 2 post data", func(t *testing.T) {
|
||||
|
||||
t.Run("200 post data", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`))
|
||||
h32.ServeHTTP(w, r)
|
||||
resp := checkStatus(w, "success")
|
||||
resp := checkStatus(w, "success", http.StatusOK)
|
||||
t.Log(resp.Data)
|
||||
if resp.Data.Price != 20 {
|
||||
t.Fatalf("wrong price: %d %d", resp.Data.Price, 10)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("3 2 post data error", func(t *testing.T) {
|
||||
t.Run("400 post data error", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
|
||||
h32.ServeHTTP(w, r)
|
||||
resp := checkStatus(w, "error")
|
||||
resp := checkStatus(w, "error", http.StatusBadRequest)
|
||||
if resp.Error != "price is empty" {
|
||||
t.Fatalf("wrong error")
|
||||
}
|
||||
})
|
||||
|
||||
// fn check
|
||||
shouldPanic := func() {
|
||||
r := recover()
|
||||
if r == nil {
|
||||
t.Fatalf("should panic")
|
||||
}
|
||||
t.Log(r)
|
||||
}
|
||||
|
||||
t.Run("2 0 panic", func(t *testing.T) {
|
||||
defer shouldPanic()
|
||||
JSONHandler(func(w http.ResponseWriter, r *http.Request) {})
|
||||
})
|
||||
|
||||
t.Run("2 1 panic return value", func(t *testing.T) {
|
||||
defer shouldPanic()
|
||||
JSONHandler(func(w http.ResponseWriter, r *http.Request) string {
|
||||
return ""
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("2 1 panic arguments", func(t *testing.T) {
|
||||
defer shouldPanic()
|
||||
JSONHandler(func(r *http.Request, w http.ResponseWriter) error {
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("3 1 panic arguments", func(t *testing.T) {
|
||||
defer shouldPanic()
|
||||
JSONHandler(func(name string, r *http.Request, w http.ResponseWriter) error {
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("3 2 panic return value", func(t *testing.T) {
|
||||
defer shouldPanic()
|
||||
//lint:ignore ST1008 intentional
|
||||
JSONHandler(func(name string, r *http.Request, w http.ResponseWriter) (error, string) {
|
||||
return nil, "panic"
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("2 2 forbidden", func(t *testing.T) {
|
||||
code := http.StatusForbidden
|
||||
body := []byte("forbidden")
|
||||
h := JSONHandler(func(w http.ResponseWriter, r *http.Request) (*Data, error) {
|
||||
w.WriteHeader(code)
|
||||
w.Write(body)
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
t.Run("500 internal server error", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
h.ServeHTTP(w, r)
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("wrong code: %d %d", w.Code, code)
|
||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "root"}`))
|
||||
h32.ServeHTTP(w, r)
|
||||
resp := checkStatus(w, "error", http.StatusInternalServerError)
|
||||
if resp.Error != "internal server error" {
|
||||
t.Fatalf("wrong error")
|
||||
}
|
||||
if !bytes.Equal(w.Body.Bytes(), []byte("forbidden")) {
|
||||
t.Fatalf("wrong body: %s %s", w.Body.Bytes(), body)
|
||||
})
|
||||
|
||||
t.Run("500 misuse", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/", nil)
|
||||
JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||
return http.StatusOK, make(chan int), nil
|
||||
}).ServeHTTP(w, r)
|
||||
resp := checkStatus(w, "error", http.StatusInternalServerError)
|
||||
if resp.Error != "json marshal error" {
|
||||
t.Fatalf("wrong error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("500 empty status code", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/", nil)
|
||||
JSONHandlerFunc(func(r *http.Request) (status int, data interface{}, err error) {
|
||||
return
|
||||
}).ServeHTTP(w, r)
|
||||
checkStatus(w, "error", http.StatusInternalServerError)
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user