diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 310293554..addfea4e0 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -15,6 +15,7 @@ "math/rand" "net" "net/http" + "net/url" "runtime" "sort" "strconv" @@ -26,6 +27,7 @@ "inet.af/netaddr" "tailscale.com/hostinfo" "tailscale.com/net/dns/publicdns" + "tailscale.com/net/dnscache" "tailscale.com/net/neterror" "tailscale.com/net/netns" "tailscale.com/net/tsdial" @@ -332,21 +334,47 @@ func (f *forwarder) packetListener(ip netaddr.IP) (packetListener, error) { return lc, nil } +// getKnownDoHClient returns an HTTP client for a DoH provider (such as Google +// or Cloudflare DNS), as a function of one of its (usually four) IPs. +// +// The provided IP is only used to determine the DoH provider; it is not +// prioritized among the set of IPs that are used by the provider. func (f *forwarder) getKnownDoHClient(ip netaddr.IP) (urlBase string, c *http.Client, ok bool) { urlBase, ok = publicdns.KnownDoH()[ip] if !ok { - return + return "", nil, false } + c, ok = f.getKnownDoHClientForProvider(urlBase) + if !ok { + return "", nil, false + } + return urlBase, c, true +} +// getKnownDoHClientForProvider returns an HTTP client for a specific DoH +// provider named by its DoH base URL (like "https://dns.google/dns-query"). +// +// The returned client race/Happy Eyeballs dials all IPs for urlBase (usually +// 4), as statically known by the publicdns package. +func (f *forwarder) getKnownDoHClientForProvider(urlBase string) (c *http.Client, ok bool) { f.mu.Lock() defer f.mu.Unlock() if c, ok := f.dohClient[urlBase]; ok { - return urlBase, c, true + return c, true } - if f.dohClient == nil { - f.dohClient = map[string]*http.Client{} + allIPs := publicdns.DoHIPsOfBase()[urlBase] + if len(allIPs) == 0 { + return nil, false + } + dohURL, err := url.Parse(urlBase) + if err != nil { + return nil, false } nsDialer := netns.NewDialer(f.logf) + dialer := dnscache.Dialer(nsDialer.DialContext, &dnscache.Resolver{ + SingleHost: dohURL.Hostname(), + SingleHostStaticResult: allIPs, + }) c = &http.Client{ Transport: &http.Transport{ IdleConnTimeout: dohTransportTimeout, @@ -354,21 +382,15 @@ func (f *forwarder) getKnownDoHClient(ip netaddr.IP) (urlBase string, c *http.Cl if !strings.HasPrefix(netw, "tcp") { return nil, fmt.Errorf("unexpected network %q", netw) } - c, err := nsDialer.DialContext(ctx, "tcp", net.JoinHostPort(ip.String(), "443")) - // If v4 failed, try an equivalent v6 also in the time remaining. - if err != nil && ctx.Err() == nil { - if ip6, ok := publicdns.DoHV6(urlBase); ok && ip.Is4() { - if c6, err := nsDialer.DialContext(ctx, "tcp", net.JoinHostPort(ip6.String(), "443")); err == nil { - return c6, nil - } - } - } - return c, err + return dialer(ctx, netw, addr) }, }, } + if f.dohClient == nil { + f.dohClient = map[string]*http.Client{} + } f.dohClient[urlBase] = c - return urlBase, c, true + return c, true } const dohType = "application/dns-message" diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index fb19ea8e6..f90761af6 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -5,6 +5,7 @@ package resolver import ( + "flag" "fmt" "net" "reflect" @@ -169,6 +170,25 @@ func TestMaxDoHInFlight(t *testing.T) { } } +var testDNS = flag.Bool("test-dns", false, "run tests that require a working DNS server") + +func TestGetKnownDoHClientForProvider(t *testing.T) { + var fwd forwarder + c, ok := fwd.getKnownDoHClientForProvider("https://dns.google/dns-query") + if !ok { + t.Fatal("not found") + } + if !*testDNS { + t.Skip("skipping without --test-dns") + } + res, err := c.Head("https://dns.google/") + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + t.Logf("Got: %+v", res) +} + func BenchmarkNameFromQuery(b *testing.B) { builder := dns.NewBuilder(nil, dns.Header{}) builder.StartQuestions() diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go index c1536e0ad..830b894b3 100644 --- a/net/dnscache/dnscache.go +++ b/net/dnscache/dnscache.go @@ -72,6 +72,15 @@ type Resolver struct { // if a refresh fails. UseLastGood bool + // SingleHostStaticResult, if non-nil, is the static result of IPs that is returned + // by Resolver.LookupIP for any hostname. When non-nil, SingleHost must also be + // set with the expected name. + SingleHostStaticResult []netaddr.IP + + // SingleHost is the hostname that SingleHostStaticResult is for. + // It is required when SingleHostStaticResult is present. + SingleHost string + sf singleflight.Group mu sync.Mutex @@ -108,6 +117,22 @@ func (r *Resolver) ttl() time.Duration { // If err is nil, ip will be non-nil. The v6 address may be nil even // with a nil error. func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 net.IP, allIPs []net.IPAddr, err error) { + if r.SingleHostStaticResult != nil { + if r.SingleHost != host { + return nil, nil, nil, fmt.Errorf("dnscache: unexpected hostname %q doesn't match expected %q", host, r.SingleHost) + } + for _, naIP := range r.SingleHostStaticResult { + ipa := naIP.IPAddr() + if ip == nil && naIP.Is4() { + ip = ipa.IP + } + if v6 == nil && naIP.Is6() { + v6 = ipa.IP + } + allIPs = append(allIPs, *ipa) + } + return + } if ip := net.ParseIP(host); ip != nil { if ip4 := ip.To4(); ip4 != nil { return ip4, nil, []net.IPAddr{{IP: ip4}}, nil diff --git a/net/dnscache/dnscache_test.go b/net/dnscache/dnscache_test.go index 10cfd5398..6b005500d 100644 --- a/net/dnscache/dnscache_test.go +++ b/net/dnscache/dnscache_test.go @@ -8,6 +8,7 @@ "context" "errors" "flag" + "fmt" "net" "reflect" "testing" @@ -110,3 +111,33 @@ func TestDialCall_uniqueIPs(t *testing.T) { t.Errorf("got %v; want %v", got, want) } } + +func TestResolverAllHostStaticResult(t *testing.T) { + r := &Resolver{ + SingleHost: "foo.bar", + SingleHostStaticResult: []netaddr.IP{ + netaddr.MustParseIP("2001:4860:4860::8888"), + netaddr.MustParseIP("2001:4860:4860::8844"), + netaddr.MustParseIP("8.8.8.8"), + netaddr.MustParseIP("8.8.4.4"), + }, + } + ip4, ip6, allIPs, err := r.LookupIP(context.Background(), "foo.bar") + if err != nil { + t.Fatal(err) + } + if got, want := ip4.String(), "8.8.8.8"; got != want { + t.Errorf("ip4 got %q; want %q", got, want) + } + if got, want := ip6.String(), "2001:4860:4860::8888"; got != want { + t.Errorf("ip4 got %q; want %q", got, want) + } + if got, want := fmt.Sprintf("%q", allIPs), `[{"2001:4860:4860::8888" ""} {"2001:4860:4860::8844" ""} {"8.8.8.8" ""} {"8.8.4.4" ""}]`; got != want { + t.Errorf("allIPs got %q; want %q", got, want) + } + + _, _, _, err = r.LookupIP(context.Background(), "bad") + if got, want := fmt.Sprint(err), `dnscache: unexpected hostname "bad" doesn't match expected "foo.bar"`; got != want { + t.Errorf("bad dial error got %q; want %q", got, want) + } +}