control/controlclient, net/{dnscache,dnsfallback}: add DNS fallback mechanism
Updates #1405 Updates #1403 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:

committed by
Brad Fitzpatrick

parent
03c344333e
commit
9df4185c94
@ -8,6 +8,8 @@ package dnscache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
@ -18,6 +20,7 @@ import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
var single = &Resolver{
|
||||
@ -55,6 +58,10 @@ type Resolver struct {
|
||||
// If nil, net.DefaultResolver is used.
|
||||
Forward *net.Resolver
|
||||
|
||||
// LookupIPFallback optionally provides a backup DNS mechanism
|
||||
// to use if Forward returns an error or no results.
|
||||
LookupIPFallback func(ctx context.Context, host string) ([]netaddr.IP, error)
|
||||
|
||||
// TTL is how long to keep entries cached
|
||||
//
|
||||
// If zero, a default (currently 10 minutes) is used.
|
||||
@ -198,6 +205,18 @@ func (r *Resolver) lookupIP(host string) (ip, ip6 net.IP, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.lookupTimeoutForHost(host))
|
||||
defer cancel()
|
||||
ips, err := r.fwd().LookupIPAddr(ctx, host)
|
||||
if (err != nil || len(ips) == 0) && r.LookupIPFallback != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
var fips []netaddr.IP
|
||||
fips, err = r.LookupIPFallback(ctx, host)
|
||||
if err == nil {
|
||||
ips = nil
|
||||
for _, fip := range fips {
|
||||
ips = append(ips, *fip.IPAddr())
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@ -269,13 +288,33 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con
|
||||
|
||||
// Dialer returns a wrapped DialContext func that uses the provided dnsCache.
|
||||
func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
|
||||
return func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return func(ctx context.Context, network, address string) (retConn net.Conn, ret error) {
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
// Bogus. But just let the real dialer return an error rather than
|
||||
// inventing a similar one.
|
||||
return fwd(ctx, network, address)
|
||||
}
|
||||
defer func() {
|
||||
// On any failure, assume our DNS is wrong and try our fallback, if any.
|
||||
if ret == nil || dnsCache.LookupIPFallback == nil {
|
||||
return
|
||||
}
|
||||
ips, err := dnsCache.LookupIPFallback(ctx, host)
|
||||
if err != nil {
|
||||
// Return with original error
|
||||
return
|
||||
}
|
||||
for _, ip := range ips {
|
||||
dst := net.JoinHostPort(ip.String(), port)
|
||||
if c, err := fwd(ctx, network, dst); err == nil {
|
||||
retConn = c
|
||||
ret = nil
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
ip, ip6, err := dnsCache.LookupIP(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve %q: %w", host, err)
|
||||
@ -300,3 +339,62 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
|
||||
return fwd(ctx, network, dst)
|
||||
}
|
||||
}
|
||||
|
||||
var errTLSHandshakeTimeout = errors.New("timeout doing TLS handshake")
|
||||
|
||||
// TLSDialer is like Dialer but returns a func suitable for using with net/http.Transport.DialTLSContext.
|
||||
// It returns a *tls.Conn type on success.
|
||||
// On TLS cert validation failure, it can invoke a backup DNS resolution strategy.
|
||||
func TLSDialer(fwd DialContextFunc, dnsCache *Resolver, tlsConfigBase *tls.Config) DialContextFunc {
|
||||
tcpDialer := Dialer(fwd, dnsCache)
|
||||
return func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
host, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tcpConn, err := tcpDialer(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := cloneTLSConfig(tlsConfigBase)
|
||||
if cfg.ServerName == "" {
|
||||
cfg.ServerName = host
|
||||
}
|
||||
tlsConn := tls.Client(tcpConn, cfg)
|
||||
|
||||
errc := make(chan error, 2)
|
||||
handshakeCtx, handshakeTimeoutCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer handshakeTimeoutCancel()
|
||||
done := make(chan bool)
|
||||
defer close(done)
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
case <-handshakeCtx.Done():
|
||||
errc <- errTLSHandshakeTimeout
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
err := tlsConn.Handshake()
|
||||
handshakeTimeoutCancel()
|
||||
errc <- err
|
||||
}()
|
||||
if err := <-errc; err != nil {
|
||||
tcpConn.Close()
|
||||
// TODO: if err != errTLSHandshakeTimeout,
|
||||
// assume it might be some captive portal or
|
||||
// otherwise incorrect DNS and try the backup
|
||||
// DNS mechanism.
|
||||
return nil, err
|
||||
}
|
||||
return tlsConn, nil
|
||||
}
|
||||
}
|
||||
|
||||
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
||||
if cfg == nil {
|
||||
return &tls.Config{}
|
||||
}
|
||||
return cfg.Clone()
|
||||
}
|
||||
|
Reference in New Issue
Block a user