From ecea6cb994e228f20d03c0f99b7a5f27e20cbac6 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 18 Apr 2022 12:50:26 -0700 Subject: [PATCH] net/dns/resolver: make DoH dialer use existing dnscache happy eyeball dialer Simplify the ability to reason about the DoH dialing code by reusing the dnscache's dialer we already have. Also, reduce the scope of the "ip" variable we don't want to close over. This necessarily adds a new field to dnscache.Resolver: SingleHostStaticResult, for when the caller already knows the IPs to be returned. Change-Id: I9f2aef7926f649137a5a3e63eebad6a3fffa48c0 Signed-off-by: Brad Fitzpatrick --- net/dns/resolver/forwarder.go | 52 +++++++++++++++++++++--------- net/dns/resolver/forwarder_test.go | 20 ++++++++++++ net/dnscache/dnscache.go | 25 ++++++++++++++ net/dnscache/dnscache_test.go | 31 ++++++++++++++++++ 4 files changed, 113 insertions(+), 15 deletions(-) diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 310293554..addfea4e0 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -15,6 +15,7 @@ "math/rand" "net" "net/http" + "net/url" "runtime" "sort" "strconv" @@ -26,6 +27,7 @@ "inet.af/netaddr" "tailscale.com/hostinfo" "tailscale.com/net/dns/publicdns" + "tailscale.com/net/dnscache" "tailscale.com/net/neterror" "tailscale.com/net/netns" "tailscale.com/net/tsdial" @@ -332,21 +334,47 @@ func (f *forwarder) packetListener(ip netaddr.IP) (packetListener, error) { return lc, nil } +// getKnownDoHClient returns an HTTP client for a DoH provider (such as Google +// or Cloudflare DNS), as a function of one of its (usually four) IPs. +// +// The provided IP is only used to determine the DoH provider; it is not +// prioritized among the set of IPs that are used by the provider. func (f *forwarder) getKnownDoHClient(ip netaddr.IP) (urlBase string, c *http.Client, ok bool) { urlBase, ok = publicdns.KnownDoH()[ip] if !ok { - return + return "", nil, false } + c, ok = f.getKnownDoHClientForProvider(urlBase) + if !ok { + return "", nil, false + } + return urlBase, c, true +} +// getKnownDoHClientForProvider returns an HTTP client for a specific DoH +// provider named by its DoH base URL (like "https://dns.google/dns-query"). +// +// The returned client race/Happy Eyeballs dials all IPs for urlBase (usually +// 4), as statically known by the publicdns package. +func (f *forwarder) getKnownDoHClientForProvider(urlBase string) (c *http.Client, ok bool) { f.mu.Lock() defer f.mu.Unlock() if c, ok := f.dohClient[urlBase]; ok { - return urlBase, c, true + return c, true } - if f.dohClient == nil { - f.dohClient = map[string]*http.Client{} + allIPs := publicdns.DoHIPsOfBase()[urlBase] + if len(allIPs) == 0 { + return nil, false + } + dohURL, err := url.Parse(urlBase) + if err != nil { + return nil, false } nsDialer := netns.NewDialer(f.logf) + dialer := dnscache.Dialer(nsDialer.DialContext, &dnscache.Resolver{ + SingleHost: dohURL.Hostname(), + SingleHostStaticResult: allIPs, + }) c = &http.Client{ Transport: &http.Transport{ IdleConnTimeout: dohTransportTimeout, @@ -354,21 +382,15 @@ func (f *forwarder) getKnownDoHClient(ip netaddr.IP) (urlBase string, c *http.Cl if !strings.HasPrefix(netw, "tcp") { return nil, fmt.Errorf("unexpected network %q", netw) } - c, err := nsDialer.DialContext(ctx, "tcp", net.JoinHostPort(ip.String(), "443")) - // If v4 failed, try an equivalent v6 also in the time remaining. - if err != nil && ctx.Err() == nil { - if ip6, ok := publicdns.DoHV6(urlBase); ok && ip.Is4() { - if c6, err := nsDialer.DialContext(ctx, "tcp", net.JoinHostPort(ip6.String(), "443")); err == nil { - return c6, nil - } - } - } - return c, err + return dialer(ctx, netw, addr) }, }, } + if f.dohClient == nil { + f.dohClient = map[string]*http.Client{} + } f.dohClient[urlBase] = c - return urlBase, c, true + return c, true } const dohType = "application/dns-message" diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index fb19ea8e6..f90761af6 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -5,6 +5,7 @@ package resolver import ( + "flag" "fmt" "net" "reflect" @@ -169,6 +170,25 @@ func TestMaxDoHInFlight(t *testing.T) { } } +var testDNS = flag.Bool("test-dns", false, "run tests that require a working DNS server") + +func TestGetKnownDoHClientForProvider(t *testing.T) { + var fwd forwarder + c, ok := fwd.getKnownDoHClientForProvider("https://dns.google/dns-query") + if !ok { + t.Fatal("not found") + } + if !*testDNS { + t.Skip("skipping without --test-dns") + } + res, err := c.Head("https://dns.google/") + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + t.Logf("Got: %+v", res) +} + func BenchmarkNameFromQuery(b *testing.B) { builder := dns.NewBuilder(nil, dns.Header{}) builder.StartQuestions() diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go index c1536e0ad..830b894b3 100644 --- a/net/dnscache/dnscache.go +++ b/net/dnscache/dnscache.go @@ -72,6 +72,15 @@ type Resolver struct { // if a refresh fails. UseLastGood bool + // SingleHostStaticResult, if non-nil, is the static result of IPs that is returned + // by Resolver.LookupIP for any hostname. When non-nil, SingleHost must also be + // set with the expected name. + SingleHostStaticResult []netaddr.IP + + // SingleHost is the hostname that SingleHostStaticResult is for. + // It is required when SingleHostStaticResult is present. + SingleHost string + sf singleflight.Group mu sync.Mutex @@ -108,6 +117,22 @@ func (r *Resolver) ttl() time.Duration { // If err is nil, ip will be non-nil. The v6 address may be nil even // with a nil error. func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 net.IP, allIPs []net.IPAddr, err error) { + if r.SingleHostStaticResult != nil { + if r.SingleHost != host { + return nil, nil, nil, fmt.Errorf("dnscache: unexpected hostname %q doesn't match expected %q", host, r.SingleHost) + } + for _, naIP := range r.SingleHostStaticResult { + ipa := naIP.IPAddr() + if ip == nil && naIP.Is4() { + ip = ipa.IP + } + if v6 == nil && naIP.Is6() { + v6 = ipa.IP + } + allIPs = append(allIPs, *ipa) + } + return + } if ip := net.ParseIP(host); ip != nil { if ip4 := ip.To4(); ip4 != nil { return ip4, nil, []net.IPAddr{{IP: ip4}}, nil diff --git a/net/dnscache/dnscache_test.go b/net/dnscache/dnscache_test.go index 10cfd5398..6b005500d 100644 --- a/net/dnscache/dnscache_test.go +++ b/net/dnscache/dnscache_test.go @@ -8,6 +8,7 @@ "context" "errors" "flag" + "fmt" "net" "reflect" "testing" @@ -110,3 +111,33 @@ func TestDialCall_uniqueIPs(t *testing.T) { t.Errorf("got %v; want %v", got, want) } } + +func TestResolverAllHostStaticResult(t *testing.T) { + r := &Resolver{ + SingleHost: "foo.bar", + SingleHostStaticResult: []netaddr.IP{ + netaddr.MustParseIP("2001:4860:4860::8888"), + netaddr.MustParseIP("2001:4860:4860::8844"), + netaddr.MustParseIP("8.8.8.8"), + netaddr.MustParseIP("8.8.4.4"), + }, + } + ip4, ip6, allIPs, err := r.LookupIP(context.Background(), "foo.bar") + if err != nil { + t.Fatal(err) + } + if got, want := ip4.String(), "8.8.8.8"; got != want { + t.Errorf("ip4 got %q; want %q", got, want) + } + if got, want := ip6.String(), "2001:4860:4860::8888"; got != want { + t.Errorf("ip4 got %q; want %q", got, want) + } + if got, want := fmt.Sprintf("%q", allIPs), `[{"2001:4860:4860::8888" ""} {"2001:4860:4860::8844" ""} {"8.8.8.8" ""} {"8.8.4.4" ""}]`; got != want { + t.Errorf("allIPs got %q; want %q", got, want) + } + + _, _, _, err = r.LookupIP(context.Background(), "bad") + if got, want := fmt.Sprint(err), `dnscache: unexpected hostname "bad" doesn't match expected "foo.bar"`; got != want { + t.Errorf("bad dial error got %q; want %q", got, want) + } +}