net/dns: add Windows group policy notifications to the NRPT rule manager

As discussed in previous PRs, we can register for notifications when group
policies are updated and act accordingly.

This patch changes nrptRuleDatabase to receive notifications that group policy
has changed and automatically move our NRPT rules between the local and
group policy subkeys as needed.

Signed-off-by: Aaron Klotz <aaron@tailscale.com>
This commit is contained in:
Aaron Klotz
2022-06-29 15:02:23 -06:00
parent f17873e0f4
commit 1cae618b03
4 changed files with 631 additions and 62 deletions

View File

@ -5,6 +5,7 @@
package dns
import (
"context"
"fmt"
"math/rand"
"strings"
@ -20,11 +21,6 @@ import (
const testGPRuleID = "{7B1B6151-84E6-41A3-8967-62F7F7B45687}"
var (
procRegisterGPNotification = libUserenv.NewProc("RegisterGPNotification")
procUnregisterGPNotification = libUserenv.NewProc("UnregisterGPNotification")
)
func TestManagerWindowsLocal(t *testing.T) {
if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() {
t.Skipf("test requires running as elevated user on Windows 10+")
@ -53,6 +49,121 @@ func TestManagerWindowsGP(t *testing.T) {
runTest(t, false)
}
func TestManagerWindowsGPMove(t *testing.T) {
if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() {
t.Skipf("test requires running as elevated user on Windows 10+")
}
checkGPNotificationsWork(t)
logf := func(format string, args ...any) {
t.Logf(format, args...)
}
fakeInterface, err := windows.GenerateGUID()
if err != nil {
t.Fatalf("windows.GenerateGUID: %v\n", err)
}
delIfKey, err := createFakeInterfaceKey(t, fakeInterface)
if err != nil {
t.Fatalf("createFakeInterfaceKey: %v\n", err)
}
defer delIfKey()
cfg, err := NewOSConfigurator(logf, fakeInterface.String())
if err != nil {
t.Fatalf("NewOSConfigurator: %v\n", err)
}
mgr := cfg.(windowsManager)
defer mgr.Close()
usingGP := mgr.nrptDB.writeAsGP
if usingGP {
t.Fatalf("usingGP %v, want %v\n", usingGP, false)
}
regWatcher, err := newRegKeyWatcher()
if err != nil {
t.Fatalf("newRegKeyWatcher error %v\n", err)
}
// Upon initialization of cfg, we should not have any NRPT rules
ensureNoRules(t)
resolvers := []netaddr.IP{netaddr.MustParseIP("1.1.1.1")}
domains := genRandomSubdomains(t, 1)
// 1. Populate local NRPT
err = mgr.setSplitDNS(resolvers, domains)
if err != nil {
t.Fatalf("setSplitDNS: %v\n", err)
}
t.Logf("Validating that local NRPT is populated...\n")
validateRegistry(t, nrptBaseLocal, domains)
ensureNoRulesInSubkey(t, nrptBaseGP)
// 2. Create fake GP key and refresh
t.Logf("Creating fake group policy key and refreshing...\n")
err = createFakeGPKey()
if err != nil {
t.Fatalf("createFakeGPKey: %v\n", err)
}
err = regWatcher.watch()
if err != nil {
t.Fatalf("regWatcher.watch: %v\n", err)
}
err = testDoRefresh()
if err != nil {
t.Fatalf("testDoRefresh: %v\n", err)
}
err = regWatcher.wait()
if err != nil {
t.Fatalf("regWatcher.wait: %v\n", err)
}
// 3. Check that local NRPT is empty and GP is populated
t.Logf("Validating that group policy NRPT is populated...\n")
validateRegistry(t, nrptBaseGP, domains)
ensureNoRulesInSubkey(t, nrptBaseLocal)
// 4. Delete fake GP key and refresh
t.Logf("Deleting fake group policy key and refreshing...\n")
deleteFakeGPKey(t)
err = regWatcher.watch()
if err != nil {
t.Fatalf("regWatcher.watch: %v\n", err)
}
err = testDoRefresh()
if err != nil {
t.Fatalf("testDoRefresh: %v\n", err)
}
err = regWatcher.wait()
if err != nil {
t.Fatalf("regWatcher.wait: %v\n", err)
}
// 5. Check that local NRPT is populated and GP is empty
t.Logf("Validating that local NRPT is populated...\n")
validateRegistry(t, nrptBaseLocal, domains)
ensureNoRulesInSubkey(t, nrptBaseGP)
// 6. Cleanup
t.Logf("Cleaning up...\n")
err = mgr.setSplitDNS(nil, domains)
if err != nil {
t.Fatalf("setSplitDNS: %v\n", err)
}
ensureNoRules(t)
}
func checkGPNotificationsWork(t *testing.T) {
// Test to ensure that RegisterGPNotification work on this machine,
// otherwise this test will fail.
@ -83,11 +194,18 @@ func runTest(t *testing.T, isLocal bool) {
t.Fatalf("windows.GenerateGUID: %v\n", err)
}
delIfKey, err := createFakeInterfaceKey(t, fakeInterface)
if err != nil {
t.Fatalf("createFakeInterfaceKey: %v\n", err)
}
defer delIfKey()
cfg, err := NewOSConfigurator(logf, fakeInterface.String())
if err != nil {
t.Fatalf("NewOSConfigurator: %v\n", err)
}
mgr := cfg.(windowsManager)
defer mgr.Close()
usingGP := mgr.nrptDB.writeAsGP
if isLocal == usingGP {
@ -99,25 +217,7 @@ func runTest(t *testing.T, isLocal bool) {
resolvers := []netaddr.IP{netaddr.MustParseIP("1.1.1.1")}
domains := make([]dnsname.FQDN, 0, 2*nrptMaxDomainsPerRule+1)
r := rand.New(rand.NewSource(time.Now().UnixNano()))
const charset = "abcdefghijklmnopqrstuvwxyz"
// Just generate a bunch of random subdomains
for len(domains) < cap(domains) {
l := r.Intn(19) + 1
b := make([]byte, l)
for i, _ := range b {
b[i] = charset[r.Intn(len(charset))]
}
d := string(b) + ".example.com"
fqdn, err := dnsname.ToFQDN(d)
if err != nil {
t.Fatalf("dnsname.ToFQDN: %v\n", err)
}
domains = append(domains, fqdn)
}
domains := genRandomSubdomains(t, 2*nrptMaxDomainsPerRule+1)
cases := []int{
1,
@ -238,6 +338,32 @@ func deleteFakeGPKey(t *testing.T) {
}
}
func createFakeInterfaceKey(t *testing.T, guid windows.GUID) (func(), error) {
basePaths := []string{ipv4RegBase, ipv6RegBase}
keyPaths := make([]string, 0, len(basePaths))
for _, basePath := range basePaths {
keyPath := fmt.Sprintf(`%s\Interfaces\%s`, basePath, guid)
key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, keyPath, registry.SET_VALUE)
if err != nil {
return nil, err
}
key.Close()
keyPaths = append(keyPaths, keyPath)
}
result := func() {
for _, keyPath := range keyPaths {
if err := registry.DeleteKey(registry.LOCAL_MACHINE, keyPath); err != nil {
t.Fatalf("deleting fake interface key \"%s\": %v\n", keyPath, err)
}
}
}
return result, nil
}
func ensureNoRules(t *testing.T) {
ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil)
if ruleIDs != nil {
@ -263,11 +389,29 @@ func ensureNoRulesInSubkey(t *testing.T, base string) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyName, registry.READ)
if err == nil {
key.Close()
}
if err != registry.ErrNotExist {
} else if err != registry.ErrNotExist {
t.Fatalf("%s: %q, want %q\n", keyName, err, registry.ErrNotExist)
}
}
if base == nrptBaseGP {
// When dealing with the group policy subkey, we want the base key to
// also be absent.
key, err := registry.OpenKey(registry.LOCAL_MACHINE, base, registry.READ)
if err == nil {
key.Close()
isEmpty, err := isPolicyConfigSubkeyEmpty()
if err != nil {
t.Fatalf("isPolicyConfigSubkeyEmpty: %v", err)
}
if isEmpty {
t.Errorf("Unexpectedly found group policy key\n")
}
} else if err != registry.ErrNotExist {
t.Errorf("Group policy key error: %q, want %q\n", err, registry.ErrNotExist)
}
}
}
func ensureNoSingleRule(t *testing.T, base string) {
@ -332,6 +476,40 @@ func getSavedDomainsForRule(base, ruleID string) ([]string, error) {
return result, err
}
func genRandomSubdomains(t *testing.T, n int) []dnsname.FQDN {
domains := make([]dnsname.FQDN, 0, n)
seed := time.Now().UnixNano()
t.Logf("genRandomSubdomains(%d) seed: %v\n", n, seed)
r := rand.New(rand.NewSource(seed))
const charset = "abcdefghijklmnopqrstuvwxyz"
for len(domains) < cap(domains) {
l := r.Intn(19) + 1
b := make([]byte, l)
for i, _ := range b {
b[i] = charset[r.Intn(len(charset))]
}
d := string(b) + ".example.com"
fqdn, err := dnsname.ToFQDN(d)
if err != nil {
t.Fatalf("dnsname.ToFQDN: %v\n", err)
}
domains = append(domains, fqdn)
}
return domains
}
func testDoRefresh() (err error) {
r, _, e := procRefreshPolicyEx.Call(uintptr(1), uintptr(_RP_FORCE))
if r == 0 {
err = e
}
return err
}
// gpNotificationTracker registers with the Windows policy engine and receives
// notifications when policy refreshes occur.
type gpNotificationTracker struct {
@ -384,3 +562,103 @@ func (trk *gpNotificationTracker) Close() error {
trk.event = 0
return nil
}
type regKeyWatcher struct {
keyLocal registry.Key
keyGP registry.Key
evtLocal windows.Handle
evtGP windows.Handle
}
func newRegKeyWatcher() (*regKeyWatcher, error) {
var err error
keyLocal, _, err := registry.CreateKey(registry.LOCAL_MACHINE, nrptBaseLocal, registry.READ)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
keyLocal.Close()
}
}()
// Monitor dnsBaseGP instead of nrptBaseGP, since the latter will be
// repeatedly created and destroyed throughout the course of the test.
keyGP, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsBaseGP, registry.READ)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
keyGP.Close()
}
}()
evtLocal, err := windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
windows.CloseHandle(evtLocal)
}
}()
evtGP, err := windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return nil, err
}
result := &regKeyWatcher{
keyLocal: keyLocal,
keyGP: keyGP,
evtLocal: evtLocal,
evtGP: evtGP,
}
return result, nil
}
func (rw *regKeyWatcher) watch() error {
// We can make these waits thread-agnostic because the tests that use this code must already run on Windows 10+
err := windows.RegNotifyChangeKeyValue(windows.Handle(rw.keyLocal), true,
windows.REG_NOTIFY_CHANGE_NAME|windows.REG_NOTIFY_THREAD_AGNOSTIC, rw.evtLocal, true)
if err != nil {
return err
}
return windows.RegNotifyChangeKeyValue(windows.Handle(rw.keyGP), true,
windows.REG_NOTIFY_CHANGE_NAME|windows.REG_NOTIFY_THREAD_AGNOSTIC, rw.evtGP, true)
}
func (rw *regKeyWatcher) wait() error {
handles := []windows.Handle{
rw.evtLocal,
rw.evtGP,
}
waitCode, err := windows.WaitForMultipleObjects(
handles,
true, // Wait for both events to signal before resuming.
10000, // 10 seconds (as milliseconds)
)
const WAIT_TIMEOUT = 0x102
switch waitCode {
case WAIT_TIMEOUT:
return context.DeadlineExceeded
case windows.WAIT_FAILED:
return err
default:
return nil
}
}
func (rw *regKeyWatcher) Close() error {
rw.keyLocal.Close()
rw.keyGP.Close()
windows.CloseHandle(rw.evtLocal)
windows.CloseHandle(rw.evtGP)
return nil
}