diff --git a/ipn/message_test.go b/ipn/message_test.go index 9cd6a887e..4422a64ca 100644 --- a/ipn/message_test.go +++ b/ipn/message_test.go @@ -16,9 +16,7 @@ import ( func TestReadWrite(t *testing.T) { tstest.PanicOnLog() - - rc := tstest.NewResourceCheck() - defer rc.Assert(t) + tstest.ResourceCheck(t) buf := bytes.Buffer{} err := WriteMsg(&buf, []byte("Test string1")) @@ -64,9 +62,7 @@ func TestReadWrite(t *testing.T) { func TestClientServer(t *testing.T) { tstest.PanicOnLog() - - rc := tstest.NewResourceCheck() - defer rc.Assert(t) + tstest.ResourceCheck(t) b := &FakeBackend{} var bs *BackendServer diff --git a/portlist/portlist_test.go b/portlist/portlist_test.go index 7fde87e1b..7977303a7 100644 --- a/portlist/portlist_test.go +++ b/portlist/portlist_test.go @@ -12,8 +12,7 @@ import ( ) func TestGetList(t *testing.T) { - rc := tstest.NewResourceCheck() - defer rc.Assert(t) + tstest.ResourceCheck(t) pl, err := GetList(nil) if err != nil { @@ -26,8 +25,7 @@ func TestGetList(t *testing.T) { } func TestIgnoreLocallyBoundPorts(t *testing.T) { - rc := tstest.NewResourceCheck() - defer rc.Assert(t) + tstest.ResourceCheck(t) ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { diff --git a/tstest/resource.go b/tstest/resource.go index 7be6eecd1..47828dea8 100644 --- a/tstest/resource.go +++ b/tstest/resource.go @@ -14,64 +14,30 @@ import ( "github.com/google/go-cmp/cmp" ) -type ResourceCheck struct { - startNumRoutines int - startDump string -} - -func NewResourceCheck() *ResourceCheck { - // NOTE(apenwarr): I'd rather not pre-generate a goroutine dump here. - // However, it turns out to be tricky to debug when eg. the initial - // goroutine count > the ending goroutine count, because of course - // the missing ones are not in the final dump. Also, we have to - // render the profile as a string right away, because the - // pprof.Profile object doesn't stay stable over time. Every time - // you render the string, you might get a different answer. - r := &ResourceCheck{} - r.startNumRoutines, r.startDump = goroutineDump() - return r -} - -func goroutineDump() (int, string) { - p := pprof.Lookup("goroutine") - b := &bytes.Buffer{} - p.WriteTo(b, 1) - return p.Count(), b.String() -} - -func (r *ResourceCheck) Assert(t testing.TB) { - if t.Failed() { - // Something else went wrong. - // Assume that that is responsible for the leak - // and don't pile on a bunch of extra of output. - return - } - t.Helper() - want := r.startNumRoutines - - // Some goroutines might be still exiting, so give them a chance - got := runtime.NumGoroutine() - if want != got { - _, dump := goroutineDump() +func ResourceCheck(tb testing.TB) { + startN, startStacks := goroutines() + tb.Cleanup(func() { + if tb.Failed() { + // Something else went wrong. + return + } + tb.Helper() + // Goroutines might be still exiting. for i := 0; i < 100; i++ { - got = runtime.NumGoroutine() - if want == got { - break + if runtime.NumGoroutine() <= startN { + return } time.Sleep(1 * time.Millisecond) } - - // If the count is *still* wrong, that's a failure. - if want != got { - t.Logf("goroutine diff:\n%v\n", cmp.Diff(r.startDump, dump)) - t.Logf("goroutine count: expected %d, got %d\n", want, got) - // Don't fail if there are *fewer* goroutines than - // expected. That just might be some leftover ones - // from the previous test, which are pretty hard to - // eliminate. - if want < got { - t.Fatalf("ResourceCheck: goroutine count: expected %d, got %d\n", want, got) - } - } - } + endN, endStacks := goroutines() + tb.Logf("goroutine diff:\n%v\n", cmp.Diff(startStacks, endStacks)) + tb.Fatalf("goroutine count: expected %d, got %d\n", startN, endN) + }) +} + +func goroutines() (int, []byte) { + p := pprof.Lookup("goroutine") + b := new(bytes.Buffer) + p.WriteTo(b, 1) + return p.Count(), b.Bytes() } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index dfa1f6230..998b9bb27 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -331,8 +331,7 @@ func meshStacks(logf logger.Logf, ms []*magicStack) (cleanup func()) { func TestNewConn(t *testing.T) { tstest.PanicOnLog() - rc := tstest.NewResourceCheck() - defer rc.Assert(t) + tstest.ResourceCheck(t) epCh := make(chan string, 16) epFunc := func(endpoints []string) { @@ -398,8 +397,7 @@ func pickPort(t testing.TB) uint16 { func TestDerpIPConstant(t *testing.T) { tstest.PanicOnLog() - rc := tstest.NewResourceCheck() - defer rc.Assert(t) + tstest.ResourceCheck(t) if DerpMagicIP != derpMagicIP.String() { t.Errorf("str %q != IP %v", DerpMagicIP, derpMagicIP) @@ -411,8 +409,7 @@ func TestDerpIPConstant(t *testing.T) { func TestPickDERPFallback(t *testing.T) { tstest.PanicOnLog() - rc := tstest.NewResourceCheck() - defer rc.Assert(t) + tstest.ResourceCheck(t) c := newConn() c.derpMap = derpmap.Prod() @@ -519,8 +516,7 @@ func parseCIDR(t *testing.T, addr string) netaddr.IPPrefix { // -count=10000 to be sure. func TestDeviceStartStop(t *testing.T) { tstest.PanicOnLog() - rc := tstest.NewResourceCheck() - defer rc.Assert(t) + tstest.ResourceCheck(t) conn, err := NewConn(Options{ EndpointsFunc: func(eps []string) {}, @@ -838,8 +834,7 @@ func newPinger(t *testing.T, logf logger.Logf, src, dst *magicStack) (cleanup fu // get exercised. func testActiveDiscovery(t *testing.T, d *devices) { tstest.PanicOnLog() - rc := tstest.NewResourceCheck() - defer rc.Assert(t) + tstest.ResourceCheck(t) tlogf, setT := makeNestable(t) setT(t) @@ -900,8 +895,7 @@ func testActiveDiscovery(t *testing.T, d *devices) { func testTwoDevicePing(t *testing.T, d *devices) { tstest.PanicOnLog() - rc := tstest.NewResourceCheck() - defer rc.Assert(t) + tstest.ResourceCheck(t) // This gets reassigned inside every test, so that the connections // all log using the "current" t.Logf function. Sigh. @@ -1145,8 +1139,7 @@ func testTwoDevicePing(t *testing.T, d *devices) { // TestAddrSet tests addrSet appendDests and updateDst. func TestAddrSet(t *testing.T) { tstest.PanicOnLog() - rc := tstest.NewResourceCheck() - defer rc.Assert(t) + tstest.ResourceCheck(t) mustIPPortPtr := func(s string) *netaddr.IPPort { t.Helper() diff --git a/wgengine/tsdns/tsdns_test.go b/wgengine/tsdns/tsdns_test.go index 2eb0df479..95d32dfbb 100644 --- a/wgengine/tsdns/tsdns_test.go +++ b/wgengine/tsdns/tsdns_test.go @@ -275,8 +275,7 @@ func TestResolveReverse(t *testing.T) { } func TestDelegate(t *testing.T) { - rc := tstest.NewResourceCheck() - defer rc.Assert(t) + tstest.ResourceCheck(t) dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) dnsHandleFunc("nxdomain.site.", resolveToNXDOMAIN)