net/dns/resolver: make DoH dialer use existing dnscache happy eyeball dialer

Simplify the ability to reason about the DoH dialing code by reusing the
dnscache's dialer we already have.

Also, reduce the scope of the "ip" variable we don't want to close over.

This necessarily adds a new field to dnscache.Resolver:
SingleHostStaticResult, for when the caller already knows the IPs to be
returned.

Change-Id: I9f2aef7926f649137a5a3e63eebad6a3fffa48c0
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2022-04-18 12:50:26 -07:00
committed by Brad Fitzpatrick
parent e96dd00652
commit ecea6cb994
4 changed files with 113 additions and 15 deletions

View File

@ -72,6 +72,15 @@ type Resolver struct {
// if a refresh fails.
UseLastGood bool
// SingleHostStaticResult, if non-nil, is the static result of IPs that is returned
// by Resolver.LookupIP for any hostname. When non-nil, SingleHost must also be
// set with the expected name.
SingleHostStaticResult []netaddr.IP
// SingleHost is the hostname that SingleHostStaticResult is for.
// It is required when SingleHostStaticResult is present.
SingleHost string
sf singleflight.Group
mu sync.Mutex
@ -108,6 +117,22 @@ var debug = envknob.Bool("TS_DEBUG_DNS_CACHE")
// If err is nil, ip will be non-nil. The v6 address may be nil even
// with a nil error.
func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 net.IP, allIPs []net.IPAddr, err error) {
if r.SingleHostStaticResult != nil {
if r.SingleHost != host {
return nil, nil, nil, fmt.Errorf("dnscache: unexpected hostname %q doesn't match expected %q", host, r.SingleHost)
}
for _, naIP := range r.SingleHostStaticResult {
ipa := naIP.IPAddr()
if ip == nil && naIP.Is4() {
ip = ipa.IP
}
if v6 == nil && naIP.Is6() {
v6 = ipa.IP
}
allIPs = append(allIPs, *ipa)
}
return
}
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
return ip4, nil, []net.IPAddr{{IP: ip4}}, nil

View File

@ -8,6 +8,7 @@ import (
"context"
"errors"
"flag"
"fmt"
"net"
"reflect"
"testing"
@ -110,3 +111,33 @@ func TestDialCall_uniqueIPs(t *testing.T) {
t.Errorf("got %v; want %v", got, want)
}
}
func TestResolverAllHostStaticResult(t *testing.T) {
r := &Resolver{
SingleHost: "foo.bar",
SingleHostStaticResult: []netaddr.IP{
netaddr.MustParseIP("2001:4860:4860::8888"),
netaddr.MustParseIP("2001:4860:4860::8844"),
netaddr.MustParseIP("8.8.8.8"),
netaddr.MustParseIP("8.8.4.4"),
},
}
ip4, ip6, allIPs, err := r.LookupIP(context.Background(), "foo.bar")
if err != nil {
t.Fatal(err)
}
if got, want := ip4.String(), "8.8.8.8"; got != want {
t.Errorf("ip4 got %q; want %q", got, want)
}
if got, want := ip6.String(), "2001:4860:4860::8888"; got != want {
t.Errorf("ip4 got %q; want %q", got, want)
}
if got, want := fmt.Sprintf("%q", allIPs), `[{"2001:4860:4860::8888" ""} {"2001:4860:4860::8844" ""} {"8.8.8.8" ""} {"8.8.4.4" ""}]`; got != want {
t.Errorf("allIPs got %q; want %q", got, want)
}
_, _, _, err = r.LookupIP(context.Background(), "bad")
if got, want := fmt.Sprint(err), `dnscache: unexpected hostname "bad" doesn't match expected "foo.bar"`; got != want {
t.Errorf("bad dial error got %q; want %q", got, want)
}
}