From 5d09649b0b68ed89a313e540c777f1466dfabb86 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Fri, 19 Jul 2024 15:29:25 -0500 Subject: [PATCH] types/lazy: add (*SyncValue[T]).SetForTest method It is sometimes necessary to change a global lazy.SyncValue for the duration of a test. This PR adds a (*SyncValue[T]).SetForTest method to facilitate that. Updates #12687 Signed-off-by: Nick Khyl --- types/lazy/lazy.go | 31 +++++++ types/lazy/sync_test.go | 192 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 223 insertions(+) diff --git a/types/lazy/lazy.go b/types/lazy/lazy.go index 755b5ca6f..8bd55bdf6 100644 --- a/types/lazy/lazy.go +++ b/types/lazy/lazy.go @@ -154,3 +154,34 @@ func SyncFuncErr[T any](fill func() (T, error)) func() (T, error) { return v, err } } + +// TB is a subset of testing.TB that we use to set up test helpers. +// It's defined here to avoid pulling in the testing package. +type TB interface { + Helper() + Cleanup(func()) +} + +// SetForTest sets z's value and error. +// It's used in tests only and reverts z's state back when tb and all its +// subtests complete. +// It is not safe for concurrent use and must not be called concurrently with +// any SyncValue methods, including another call to itself. +func (z *SyncValue[T]) SetForTest(tb TB, val T, err error) { + tb.Helper() + + z.once.Do(func() {}) + oldErr, oldVal := z.err.Load(), z.v + + z.v = val + if err != nil { + z.err.Store(ptr.To(err)) + } else { + z.err.Store(nilErrPtr) + } + + tb.Cleanup(func() { + z.v = oldVal + z.err.Store(oldErr) + }) +} diff --git a/types/lazy/sync_test.go b/types/lazy/sync_test.go index ab3ed427d..8fdf9e76f 100644 --- a/types/lazy/sync_test.go +++ b/types/lazy/sync_test.go @@ -8,6 +8,8 @@ "fmt" "sync" "testing" + + "tailscale.com/types/opt" ) func TestSyncValue(t *testing.T) { @@ -147,6 +149,196 @@ func TestSyncValueConcurrent(t *testing.T) { wg.Wait() } +func TestSyncValueSetForTest(t *testing.T) { + testErr := errors.New("boom") + tests := []struct { + name string + initValue opt.Value[int] + initErr opt.Value[error] + setForTestValue int + setForTestErr error + getValue int + getErr opt.Value[error] + wantValue int + wantErr error + routines int + }{ + { + name: "GetOk", + setForTestValue: 42, + getValue: 8, + wantValue: 42, + }, + { + name: "GetOk/WithInit", + initValue: opt.ValueOf(4), + setForTestValue: 42, + getValue: 8, + wantValue: 42, + }, + { + name: "GetOk/WithInitErr", + initValue: opt.ValueOf(4), + initErr: opt.ValueOf(errors.New("blast")), + setForTestValue: 42, + getValue: 8, + wantValue: 42, + }, + { + name: "GetErr", + setForTestValue: 42, + setForTestErr: testErr, + getValue: 8, + getErr: opt.ValueOf(errors.New("ka-boom")), + wantValue: 42, + wantErr: testErr, + }, + { + name: "GetErr/NilError", + setForTestValue: 42, + setForTestErr: nil, + getValue: 8, + getErr: opt.ValueOf(errors.New("ka-boom")), + wantValue: 42, + wantErr: nil, + }, + { + name: "GetErr/WithInitErr", + initValue: opt.ValueOf(4), + initErr: opt.ValueOf(errors.New("blast")), + setForTestValue: 42, + setForTestErr: testErr, + getValue: 8, + getErr: opt.ValueOf(errors.New("ka-boom")), + wantValue: 42, + wantErr: testErr, + }, + { + name: "Concurrent/GetOk", + setForTestValue: 42, + getValue: 8, + wantValue: 42, + routines: 10000, + }, + { + name: "Concurrent/GetOk/WithInitErr", + initValue: opt.ValueOf(4), + initErr: opt.ValueOf(errors.New("blast")), + setForTestValue: 42, + getValue: 8, + wantValue: 42, + routines: 10000, + }, + { + name: "Concurrent/GetErr", + setForTestValue: 42, + setForTestErr: testErr, + getValue: 8, + getErr: opt.ValueOf(errors.New("ka-boom")), + wantValue: 42, + wantErr: testErr, + routines: 10000, + }, + { + name: "Concurrent/GetErr/WithInitErr", + initValue: opt.ValueOf(4), + initErr: opt.ValueOf(errors.New("blast")), + setForTestValue: 42, + setForTestErr: testErr, + getValue: 8, + getErr: opt.ValueOf(errors.New("ka-boom")), + wantValue: 42, + wantErr: testErr, + routines: 10000, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var v SyncValue[int] + + // Initialize the sync value with the specified value and/or error, + // if required by the test. + if initValue, ok := tt.initValue.GetOk(); ok { + var wantInitErr, gotInitErr error + var wantInitValue, gotInitValue int + wantInitValue = initValue + if initErr, ok := tt.initErr.GetOk(); ok { + wantInitErr = initErr + gotInitValue, gotInitErr = v.GetErr(func() (int, error) { return initValue, initErr }) + } else { + gotInitValue = v.Get(func() int { return initValue }) + } + + if gotInitErr != wantInitErr { + t.Fatalf("InitErr: got %v; want %v", gotInitErr, wantInitErr) + } + if gotInitValue != wantInitValue { + t.Fatalf("InitValue: got %v; want %v", gotInitValue, wantInitValue) + } + + // Verify that SetForTest reverted the error and the value during the test cleanup. + t.Cleanup(func() { + wantCleanupValue, wantCleanupErr := wantInitValue, wantInitErr + gotCleanupValue, gotCleanupErr, ok := v.PeekErr() + if !ok { + t.Fatal("SyncValue is not set after cleanup") + } + if gotCleanupErr != wantCleanupErr { + t.Fatalf("CleanupErr: got %v; want %v", gotCleanupErr, wantCleanupErr) + } + if gotCleanupValue != wantCleanupValue { + t.Fatalf("CleanupValue: got %v; want %v", gotCleanupValue, wantCleanupValue) + } + }) + } + + // Set the test value and/or error. + v.SetForTest(t, tt.setForTestValue, tt.setForTestErr) + + // Verify that the value and/or error have been set. + // This will run on either the current goroutine + // or concurrently depending on the tt.routines value. + checkSyncValue := func() { + var gotValue int + var gotErr error + if getErr, ok := tt.getErr.GetOk(); ok { + gotValue, gotErr = v.GetErr(func() (int, error) { return tt.getValue, getErr }) + } else { + gotValue = v.Get(func() int { return tt.getValue }) + } + + if gotErr != tt.wantErr { + t.Errorf("Err: got %v; want %v", gotErr, tt.wantErr) + } + if gotValue != tt.wantValue { + t.Errorf("Value: got %v; want %v", gotValue, tt.wantValue) + } + } + + switch tt.routines { + case 0: + checkSyncValue() + default: + var wg sync.WaitGroup + wg.Add(tt.routines) + start := make(chan struct{}) + for range tt.routines { + go func() { + defer wg.Done() + // Every goroutine waits for the go signal, so that more of them + // have a chance to race on the initial Get than with sequential + // goroutine starts. + <-start + checkSyncValue() + }() + } + close(start) + wg.Wait() + } + }) + } +} + func TestSyncFunc(t *testing.T) { f := SyncFunc(fortyTwo)