net/dnsfallback: more explicitly pass through logf function

Redoes the approach from #5550 and #7539 to explicitly pass in the logf
function, instead of having global state that can be overridden.

Signed-off-by: Mihai Parparita <mihai@tailscale.com>
This commit is contained in:
Mihai Parparita
2023-04-17 10:58:40 -07:00
committed by Mihai Parparita
parent 28cb1221ba
commit 9a655a1d58
8 changed files with 20 additions and 66 deletions

View File

@ -13,7 +13,6 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"net/netip"
@ -27,13 +26,18 @@ import (
"tailscale.com/net/netns"
"tailscale.com/net/tlsdial"
"tailscale.com/net/tshttpproxy"
"tailscale.com/syncs"
"tailscale.com/tailcfg"
"tailscale.com/types/logger"
"tailscale.com/util/slicesx"
)
func Lookup(ctx context.Context, host string) ([]netip.Addr, error) {
func Lookup(logf logger.Logf) func(ctx context.Context, host string) ([]netip.Addr, error) {
return func(ctx context.Context, host string) ([]netip.Addr, error) {
return lookup(ctx, host, logf)
}
}
func lookup(ctx context.Context, host string, logf logger.Logf) ([]netip.Addr, error) {
if ip, err := netip.ParseAddr(host); err == nil && ip.IsValid() {
return []netip.Addr{ip}, nil
}
@ -81,7 +85,7 @@ func Lookup(ctx context.Context, host string) ([]netip.Addr, error) {
logf("trying bootstrapDNS(%q, %q) for %q ...", cand.dnsName, cand.ip, host)
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
dm, err := bootstrapDNSMap(ctx, cand.dnsName, cand.ip, host)
dm, err := bootstrapDNSMap(ctx, cand.dnsName, cand.ip, host, logf)
if err != nil {
logf("bootstrapDNS(%q, %q) for %q error: %v", cand.dnsName, cand.ip, host, err)
continue
@ -100,7 +104,7 @@ func Lookup(ctx context.Context, host string) ([]netip.Addr, error) {
// serverName and serverIP of are, say, "derpN.tailscale.com".
// queryName is the name being sought (e.g. "controlplane.tailscale.com"), passed as hint.
func bootstrapDNSMap(ctx context.Context, serverName string, serverIP netip.Addr, queryName string) (dnsMap, error) {
func bootstrapDNSMap(ctx context.Context, serverName string, serverIP netip.Addr, queryName string, logf logger.Logf) (dnsMap, error) {
dialer := netns.NewDialer(logf)
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.Proxy = tshttpproxy.ProxyFromEnvironment
@ -194,7 +198,7 @@ var cachePath string
// UpdateCache stores the DERP map cache back to disk.
//
// The caller must not mutate 'c' after calling this function.
func UpdateCache(c *tailcfg.DERPMap) {
func UpdateCache(c *tailcfg.DERPMap, logf logger.Logf) {
// Don't do anything if nothing changed.
curr := cachedDERPMap.Load()
if reflect.DeepEqual(curr, c) {
@ -227,7 +231,7 @@ func UpdateCache(c *tailcfg.DERPMap) {
//
// This function should be called before any calls to UpdateCache, as it is not
// concurrency-safe.
func SetCachePath(path string) {
func SetCachePath(path string, logf logger.Logf) {
cachePath = path
f, err := os.Open(path)
@ -246,23 +250,3 @@ func SetCachePath(path string) {
cachedDERPMap.Store(dm)
logf("[v2] dnsfallback: SetCachePath loaded cached DERP map")
}
// logfunc stores the logging function to use for this package.
var logfunc syncs.AtomicValue[logger.Logf]
// SetLogger sets the logging function that this package will use, and returns
// the old value (which may be nil).
//
// If this function is never called, or if this function is called with a nil
// value, 'log.Printf' will be used to print logs.
func SetLogger(log logger.Logf) (old logger.Logf) {
return logfunc.Swap(log)
}
func logf(format string, args ...any) {
if lf := logfunc.Load(); lf != nil {
lf(format, args...)
} else {
log.Printf(format, args...)
}
}