diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 42563972e..3542126ed 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -755,7 +755,7 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked() error { } } - if !deepprint.UpdateHash(&e.lastEngineSigTrim, min) { + if !deepprint.UpdateHash(&e.lastEngineSigTrim, min, trimmedDisco, trackDisco, trackIPs) { // No changes return nil } diff --git a/wgengine/userspace_test.go b/wgengine/userspace_test.go index 3c0c7cb3d..6c7e05e99 100644 --- a/wgengine/userspace_test.go +++ b/wgengine/userspace_test.go @@ -7,11 +7,15 @@ import ( "bytes" "fmt" + "reflect" "testing" "time" + "github.com/tailscale/wireguard-go/wgcfg" + "go4.org/mem" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/wgengine/router" "tailscale.com/wgengine/tstun" ) @@ -77,3 +81,65 @@ func TestNoteReceiveActivity(t *testing.T) { t.Fatalf("didn't get expected reconfig") } } + +func TestUserspaceEngineReconfig(t *testing.T) { + e, err := NewFakeUserspaceEngine(t.Logf, 0) + if err != nil { + t.Fatal(err) + } + defer e.Close() + ue := e.(*userspaceEngine) + + routerCfg := &router.Config{} + + for _, discoHex := range []string{ + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", + } { + cfg := &wgcfg.Config{ + Peers: []wgcfg.Peer{ + { + AllowedIPs: []wgcfg.CIDR{ + {IP: wgcfg.IPv4(100, 100, 99, 1), Mask: 32}, + }, + Endpoints: []wgcfg.Endpoint{ + { + Host: discoHex + ".disco.tailscale", + Port: 12345, + }, + }, + }, + }, + } + + err = e.Reconfig(cfg, routerCfg) + if err != nil { + t.Fatal(err) + } + + wantRecvAt := map[tailcfg.DiscoKey]time.Time{ + dkFromHex(discoHex): time.Time{}, + } + if got := ue.recvActivityAt; !reflect.DeepEqual(got, wantRecvAt) { + t.Errorf("wrong recvActivityAt\n got: %v\nwant: %v\n", got, wantRecvAt) + } + + wantTrimmedDisco := map[tailcfg.DiscoKey]bool{ + dkFromHex(discoHex): true, + } + if got := ue.trimmedDisco; !reflect.DeepEqual(got, wantTrimmedDisco) { + t.Errorf("wrong wantTrimmedDisco\n got: %v\nwant: %v\n", got, wantTrimmedDisco) + } + } +} + +func dkFromHex(hex string) tailcfg.DiscoKey { + if len(hex) != 64 { + panic(fmt.Sprintf("%q is len %d; want 64", hex, len(hex))) + } + k, err := key.NewPublicFromHexMem(mem.S(hex[:64])) + if err != nil { + panic(fmt.Sprintf("%q is not hex: %v", hex, err)) + } + return tailcfg.DiscoKey(k) +}