net/dns: add Windows group policy notifications to the NRPT rule manager
As discussed in previous PRs, we can register for notifications when group policies are updated and act accordingly. This patch changes nrptRuleDatabase to receive notifications that group policy has changed and automatically move our NRPT rules between the local and group policy subkeys as needed. Signed-off-by: Aaron Klotz <aaron@tailscale.com>
This commit is contained in:
@ -5,6 +5,7 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
@ -20,11 +21,6 @@ import (
|
||||
|
||||
const testGPRuleID = "{7B1B6151-84E6-41A3-8967-62F7F7B45687}"
|
||||
|
||||
var (
|
||||
procRegisterGPNotification = libUserenv.NewProc("RegisterGPNotification")
|
||||
procUnregisterGPNotification = libUserenv.NewProc("UnregisterGPNotification")
|
||||
)
|
||||
|
||||
func TestManagerWindowsLocal(t *testing.T) {
|
||||
if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() {
|
||||
t.Skipf("test requires running as elevated user on Windows 10+")
|
||||
@ -53,6 +49,121 @@ func TestManagerWindowsGP(t *testing.T) {
|
||||
runTest(t, false)
|
||||
}
|
||||
|
||||
func TestManagerWindowsGPMove(t *testing.T) {
|
||||
if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() {
|
||||
t.Skipf("test requires running as elevated user on Windows 10+")
|
||||
}
|
||||
|
||||
checkGPNotificationsWork(t)
|
||||
|
||||
logf := func(format string, args ...any) {
|
||||
t.Logf(format, args...)
|
||||
}
|
||||
|
||||
fakeInterface, err := windows.GenerateGUID()
|
||||
if err != nil {
|
||||
t.Fatalf("windows.GenerateGUID: %v\n", err)
|
||||
}
|
||||
|
||||
delIfKey, err := createFakeInterfaceKey(t, fakeInterface)
|
||||
if err != nil {
|
||||
t.Fatalf("createFakeInterfaceKey: %v\n", err)
|
||||
}
|
||||
defer delIfKey()
|
||||
|
||||
cfg, err := NewOSConfigurator(logf, fakeInterface.String())
|
||||
if err != nil {
|
||||
t.Fatalf("NewOSConfigurator: %v\n", err)
|
||||
}
|
||||
mgr := cfg.(windowsManager)
|
||||
defer mgr.Close()
|
||||
|
||||
usingGP := mgr.nrptDB.writeAsGP
|
||||
if usingGP {
|
||||
t.Fatalf("usingGP %v, want %v\n", usingGP, false)
|
||||
}
|
||||
|
||||
regWatcher, err := newRegKeyWatcher()
|
||||
if err != nil {
|
||||
t.Fatalf("newRegKeyWatcher error %v\n", err)
|
||||
}
|
||||
|
||||
// Upon initialization of cfg, we should not have any NRPT rules
|
||||
ensureNoRules(t)
|
||||
|
||||
resolvers := []netaddr.IP{netaddr.MustParseIP("1.1.1.1")}
|
||||
domains := genRandomSubdomains(t, 1)
|
||||
|
||||
// 1. Populate local NRPT
|
||||
err = mgr.setSplitDNS(resolvers, domains)
|
||||
if err != nil {
|
||||
t.Fatalf("setSplitDNS: %v\n", err)
|
||||
}
|
||||
|
||||
t.Logf("Validating that local NRPT is populated...\n")
|
||||
validateRegistry(t, nrptBaseLocal, domains)
|
||||
ensureNoRulesInSubkey(t, nrptBaseGP)
|
||||
|
||||
// 2. Create fake GP key and refresh
|
||||
t.Logf("Creating fake group policy key and refreshing...\n")
|
||||
err = createFakeGPKey()
|
||||
if err != nil {
|
||||
t.Fatalf("createFakeGPKey: %v\n", err)
|
||||
}
|
||||
|
||||
err = regWatcher.watch()
|
||||
if err != nil {
|
||||
t.Fatalf("regWatcher.watch: %v\n", err)
|
||||
}
|
||||
|
||||
err = testDoRefresh()
|
||||
if err != nil {
|
||||
t.Fatalf("testDoRefresh: %v\n", err)
|
||||
}
|
||||
|
||||
err = regWatcher.wait()
|
||||
if err != nil {
|
||||
t.Fatalf("regWatcher.wait: %v\n", err)
|
||||
}
|
||||
|
||||
// 3. Check that local NRPT is empty and GP is populated
|
||||
t.Logf("Validating that group policy NRPT is populated...\n")
|
||||
validateRegistry(t, nrptBaseGP, domains)
|
||||
ensureNoRulesInSubkey(t, nrptBaseLocal)
|
||||
|
||||
// 4. Delete fake GP key and refresh
|
||||
t.Logf("Deleting fake group policy key and refreshing...\n")
|
||||
deleteFakeGPKey(t)
|
||||
|
||||
err = regWatcher.watch()
|
||||
if err != nil {
|
||||
t.Fatalf("regWatcher.watch: %v\n", err)
|
||||
}
|
||||
|
||||
err = testDoRefresh()
|
||||
if err != nil {
|
||||
t.Fatalf("testDoRefresh: %v\n", err)
|
||||
}
|
||||
|
||||
err = regWatcher.wait()
|
||||
if err != nil {
|
||||
t.Fatalf("regWatcher.wait: %v\n", err)
|
||||
}
|
||||
|
||||
// 5. Check that local NRPT is populated and GP is empty
|
||||
t.Logf("Validating that local NRPT is populated...\n")
|
||||
validateRegistry(t, nrptBaseLocal, domains)
|
||||
ensureNoRulesInSubkey(t, nrptBaseGP)
|
||||
|
||||
// 6. Cleanup
|
||||
t.Logf("Cleaning up...\n")
|
||||
err = mgr.setSplitDNS(nil, domains)
|
||||
if err != nil {
|
||||
t.Fatalf("setSplitDNS: %v\n", err)
|
||||
}
|
||||
ensureNoRules(t)
|
||||
}
|
||||
|
||||
func checkGPNotificationsWork(t *testing.T) {
|
||||
// Test to ensure that RegisterGPNotification work on this machine,
|
||||
// otherwise this test will fail.
|
||||
@ -83,11 +194,18 @@ func runTest(t *testing.T, isLocal bool) {
|
||||
t.Fatalf("windows.GenerateGUID: %v\n", err)
|
||||
}
|
||||
|
||||
delIfKey, err := createFakeInterfaceKey(t, fakeInterface)
|
||||
if err != nil {
|
||||
t.Fatalf("createFakeInterfaceKey: %v\n", err)
|
||||
}
|
||||
defer delIfKey()
|
||||
|
||||
cfg, err := NewOSConfigurator(logf, fakeInterface.String())
|
||||
if err != nil {
|
||||
t.Fatalf("NewOSConfigurator: %v\n", err)
|
||||
}
|
||||
mgr := cfg.(windowsManager)
|
||||
defer mgr.Close()
|
||||
|
||||
usingGP := mgr.nrptDB.writeAsGP
|
||||
if isLocal == usingGP {
|
||||
@ -99,25 +217,7 @@ func runTest(t *testing.T, isLocal bool) {
|
||||
|
||||
resolvers := []netaddr.IP{netaddr.MustParseIP("1.1.1.1")}
|
||||
|
||||
domains := make([]dnsname.FQDN, 0, 2*nrptMaxDomainsPerRule+1)
|
||||
|
||||
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz"
|
||||
|
||||
// Just generate a bunch of random subdomains
|
||||
for len(domains) < cap(domains) {
|
||||
l := r.Intn(19) + 1
|
||||
b := make([]byte, l)
|
||||
for i, _ := range b {
|
||||
b[i] = charset[r.Intn(len(charset))]
|
||||
}
|
||||
d := string(b) + ".example.com"
|
||||
fqdn, err := dnsname.ToFQDN(d)
|
||||
if err != nil {
|
||||
t.Fatalf("dnsname.ToFQDN: %v\n", err)
|
||||
}
|
||||
domains = append(domains, fqdn)
|
||||
}
|
||||
domains := genRandomSubdomains(t, 2*nrptMaxDomainsPerRule+1)
|
||||
|
||||
cases := []int{
|
||||
1,
|
||||
@ -238,6 +338,32 @@ func deleteFakeGPKey(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createFakeInterfaceKey(t *testing.T, guid windows.GUID) (func(), error) {
|
||||
basePaths := []string{ipv4RegBase, ipv6RegBase}
|
||||
keyPaths := make([]string, 0, len(basePaths))
|
||||
|
||||
for _, basePath := range basePaths {
|
||||
keyPath := fmt.Sprintf(`%s\Interfaces\%s`, basePath, guid)
|
||||
key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, keyPath, registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key.Close()
|
||||
|
||||
keyPaths = append(keyPaths, keyPath)
|
||||
}
|
||||
|
||||
result := func() {
|
||||
for _, keyPath := range keyPaths {
|
||||
if err := registry.DeleteKey(registry.LOCAL_MACHINE, keyPath); err != nil {
|
||||
t.Fatalf("deleting fake interface key \"%s\": %v\n", keyPath, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func ensureNoRules(t *testing.T) {
|
||||
ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil)
|
||||
if ruleIDs != nil {
|
||||
@ -263,11 +389,29 @@ func ensureNoRulesInSubkey(t *testing.T, base string) {
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyName, registry.READ)
|
||||
if err == nil {
|
||||
key.Close()
|
||||
}
|
||||
if err != registry.ErrNotExist {
|
||||
} else if err != registry.ErrNotExist {
|
||||
t.Fatalf("%s: %q, want %q\n", keyName, err, registry.ErrNotExist)
|
||||
}
|
||||
}
|
||||
|
||||
if base == nrptBaseGP {
|
||||
// When dealing with the group policy subkey, we want the base key to
|
||||
// also be absent.
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE, base, registry.READ)
|
||||
if err == nil {
|
||||
key.Close()
|
||||
|
||||
isEmpty, err := isPolicyConfigSubkeyEmpty()
|
||||
if err != nil {
|
||||
t.Fatalf("isPolicyConfigSubkeyEmpty: %v", err)
|
||||
}
|
||||
if isEmpty {
|
||||
t.Errorf("Unexpectedly found group policy key\n")
|
||||
}
|
||||
} else if err != registry.ErrNotExist {
|
||||
t.Errorf("Group policy key error: %q, want %q\n", err, registry.ErrNotExist)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ensureNoSingleRule(t *testing.T, base string) {
|
||||
@ -332,6 +476,40 @@ func getSavedDomainsForRule(base, ruleID string) ([]string, error) {
|
||||
return result, err
|
||||
}
|
||||
|
||||
func genRandomSubdomains(t *testing.T, n int) []dnsname.FQDN {
|
||||
domains := make([]dnsname.FQDN, 0, n)
|
||||
|
||||
seed := time.Now().UnixNano()
|
||||
t.Logf("genRandomSubdomains(%d) seed: %v\n", n, seed)
|
||||
|
||||
r := rand.New(rand.NewSource(seed))
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz"
|
||||
|
||||
for len(domains) < cap(domains) {
|
||||
l := r.Intn(19) + 1
|
||||
b := make([]byte, l)
|
||||
for i, _ := range b {
|
||||
b[i] = charset[r.Intn(len(charset))]
|
||||
}
|
||||
d := string(b) + ".example.com"
|
||||
fqdn, err := dnsname.ToFQDN(d)
|
||||
if err != nil {
|
||||
t.Fatalf("dnsname.ToFQDN: %v\n", err)
|
||||
}
|
||||
domains = append(domains, fqdn)
|
||||
}
|
||||
|
||||
return domains
|
||||
}
|
||||
|
||||
func testDoRefresh() (err error) {
|
||||
r, _, e := procRefreshPolicyEx.Call(uintptr(1), uintptr(_RP_FORCE))
|
||||
if r == 0 {
|
||||
err = e
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// gpNotificationTracker registers with the Windows policy engine and receives
|
||||
// notifications when policy refreshes occur.
|
||||
type gpNotificationTracker struct {
|
||||
@ -384,3 +562,103 @@ func (trk *gpNotificationTracker) Close() error {
|
||||
trk.event = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
type regKeyWatcher struct {
|
||||
keyLocal registry.Key
|
||||
keyGP registry.Key
|
||||
evtLocal windows.Handle
|
||||
evtGP windows.Handle
|
||||
}
|
||||
|
||||
func newRegKeyWatcher() (*regKeyWatcher, error) {
|
||||
var err error
|
||||
|
||||
keyLocal, _, err := registry.CreateKey(registry.LOCAL_MACHINE, nrptBaseLocal, registry.READ)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
keyLocal.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Monitor dnsBaseGP instead of nrptBaseGP, since the latter will be
|
||||
// repeatedly created and destroyed throughout the course of the test.
|
||||
keyGP, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsBaseGP, registry.READ)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
keyGP.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
evtLocal, err := windows.CreateEvent(nil, 0, 0, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
windows.CloseHandle(evtLocal)
|
||||
}
|
||||
}()
|
||||
|
||||
evtGP, err := windows.CreateEvent(nil, 0, 0, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := ®KeyWatcher{
|
||||
keyLocal: keyLocal,
|
||||
keyGP: keyGP,
|
||||
evtLocal: evtLocal,
|
||||
evtGP: evtGP,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (rw *regKeyWatcher) watch() error {
|
||||
// We can make these waits thread-agnostic because the tests that use this code must already run on Windows 10+
|
||||
err := windows.RegNotifyChangeKeyValue(windows.Handle(rw.keyLocal), true,
|
||||
windows.REG_NOTIFY_CHANGE_NAME|windows.REG_NOTIFY_THREAD_AGNOSTIC, rw.evtLocal, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return windows.RegNotifyChangeKeyValue(windows.Handle(rw.keyGP), true,
|
||||
windows.REG_NOTIFY_CHANGE_NAME|windows.REG_NOTIFY_THREAD_AGNOSTIC, rw.evtGP, true)
|
||||
}
|
||||
|
||||
func (rw *regKeyWatcher) wait() error {
|
||||
handles := []windows.Handle{
|
||||
rw.evtLocal,
|
||||
rw.evtGP,
|
||||
}
|
||||
|
||||
waitCode, err := windows.WaitForMultipleObjects(
|
||||
handles,
|
||||
true, // Wait for both events to signal before resuming.
|
||||
10000, // 10 seconds (as milliseconds)
|
||||
)
|
||||
|
||||
const WAIT_TIMEOUT = 0x102
|
||||
switch waitCode {
|
||||
case WAIT_TIMEOUT:
|
||||
return context.DeadlineExceeded
|
||||
case windows.WAIT_FAILED:
|
||||
return err
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (rw *regKeyWatcher) Close() error {
|
||||
rw.keyLocal.Close()
|
||||
rw.keyGP.Close()
|
||||
windows.CloseHandle(rw.evtLocal)
|
||||
windows.CloseHandle(rw.evtGP)
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user