control/controlclient, net/{dnscache,dnsfallback}: add DNS fallback mechanism

Updates #1405
Updates #1403

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2021-02-26 12:49:54 -08:00
committed by Brad Fitzpatrick
parent 03c344333e
commit 9df4185c94
4 changed files with 211 additions and 6 deletions

View File

@ -8,6 +8,8 @@ package dnscache
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
@ -18,6 +20,7 @@ import (
"time"
"golang.org/x/sync/singleflight"
"inet.af/netaddr"
)
var single = &Resolver{
@ -55,6 +58,10 @@ type Resolver struct {
// If nil, net.DefaultResolver is used.
Forward *net.Resolver
// LookupIPFallback optionally provides a backup DNS mechanism
// to use if Forward returns an error or no results.
LookupIPFallback func(ctx context.Context, host string) ([]netaddr.IP, error)
// TTL is how long to keep entries cached
//
// If zero, a default (currently 10 minutes) is used.
@ -198,6 +205,18 @@ func (r *Resolver) lookupIP(host string) (ip, ip6 net.IP, err error) {
ctx, cancel := context.WithTimeout(context.Background(), r.lookupTimeoutForHost(host))
defer cancel()
ips, err := r.fwd().LookupIPAddr(ctx, host)
if (err != nil || len(ips) == 0) && r.LookupIPFallback != nil {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
var fips []netaddr.IP
fips, err = r.LookupIPFallback(ctx, host)
if err == nil {
ips = nil
for _, fip := range fips {
ips = append(ips, *fip.IPAddr())
}
}
}
if err != nil {
return nil, nil, err
}
@ -269,13 +288,33 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con
// Dialer returns a wrapped DialContext func that uses the provided dnsCache.
func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
return func(ctx context.Context, network, address string) (net.Conn, error) {
return func(ctx context.Context, network, address string) (retConn net.Conn, ret error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
// Bogus. But just let the real dialer return an error rather than
// inventing a similar one.
return fwd(ctx, network, address)
}
defer func() {
// On any failure, assume our DNS is wrong and try our fallback, if any.
if ret == nil || dnsCache.LookupIPFallback == nil {
return
}
ips, err := dnsCache.LookupIPFallback(ctx, host)
if err != nil {
// Return with original error
return
}
for _, ip := range ips {
dst := net.JoinHostPort(ip.String(), port)
if c, err := fwd(ctx, network, dst); err == nil {
retConn = c
ret = nil
return
}
}
}()
ip, ip6, err := dnsCache.LookupIP(ctx, host)
if err != nil {
return nil, fmt.Errorf("failed to resolve %q: %w", host, err)
@ -300,3 +339,62 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
return fwd(ctx, network, dst)
}
}
// errTLSHandshakeTimeout is the error TLSDialer returns when the context
// bounding the TLS handshake ends before the handshake has completed.
var errTLSHandshakeTimeout = errors.New("timeout doing TLS handshake")
// TLSDialer is like Dialer but returns a func suitable for using with net/http.Transport.DialTLSContext.
// It returns a *tls.Conn type on success.
//
// The TCP connection is made with Dialer(fwd, dnsCache). The TLS client
// config is a clone of tlsConfigBase (or an empty config if it is nil); when
// its ServerName is empty, the host part of the dialed address is used.
//
// The TLS handshake is bounded by a 5 second timeout derived from the dial
// context. If that context ends first, the TCP connection is closed and
// errTLSHandshakeTimeout is returned.
//
// Invoking a backup DNS resolution strategy on TLS cert validation failure is
// not implemented here yet; see the TODO in the body.
func TLSDialer(fwd DialContextFunc, dnsCache *Resolver, tlsConfigBase *tls.Config) DialContextFunc {
	tcpDialer := Dialer(fwd, dnsCache)
	return func(ctx context.Context, network, address string) (net.Conn, error) {
		host, _, err := net.SplitHostPort(address)
		if err != nil {
			return nil, err
		}
		tcpConn, err := tcpDialer(ctx, network, address)
		if err != nil {
			return nil, err
		}

		cfg := cloneTLSConfig(tlsConfigBase)
		if cfg.ServerName == "" {
			cfg.ServerName = host
		}
		tlsConn := tls.Client(tcpConn, cfg)

		handshakeCtx, handshakeTimeoutCancel := context.WithTimeout(ctx, 5*time.Second)
		defer handshakeTimeoutCancel()

		// Buffered so the handshake goroutine can always deliver its result
		// and exit, even when we stopped waiting because of a timeout.
		errc := make(chan error, 1)
		go func() {
			errc <- tlsConn.Handshake()
		}()

		// Wait for whichever happens first. The handshake context is only
		// cancelled by its deadline, by the parent ctx, or by the deferred
		// cancel on return, so a completed handshake can never be
		// misreported as a timeout.
		select {
		case err = <-errc:
		case <-handshakeCtx.Done():
			err = errTLSHandshakeTimeout
		}
		if err != nil {
			// Closing the TCP connection also unblocks a handshake that is
			// still in progress.
			tcpConn.Close()
			// TODO: if err != errTLSHandshakeTimeout,
			// assume it might be some captive portal or
			// otherwise incorrect DNS and try the backup
			// DNS mechanism.
			return nil, err
		}
		return tlsConn, nil
	}
}
// cloneTLSConfig returns a copy of cfg that the caller may modify without
// affecting the original. A nil cfg yields a new, empty *tls.Config.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
	if cfg != nil {
		return cfg.Clone()
	}
	return new(tls.Config)
}