net/dns/resolver: add Windows ExitDNS service support, using net package

Updates #1713
Updates #835

Change-Id: Ia71e96d0632c2d617b401695ad68301b07c1c2ec
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2021-12-09 12:01:19 -08:00
committed by Brad Fitzpatrick
parent cab5c46481
commit cced414c7d
4 changed files with 775 additions and 9 deletions

View File

@ -360,7 +360,8 @@ func (r *Resolver) HandleExitNodeDNSQuery(ctx context.Context, q []byte, from ne
case "windows":
// TODO: use DnsQueryEx and write to ch.
// See https://docs.microsoft.com/en-us/windows/win32/api/windns/nf-windns-dnsqueryex.
return nil, errors.New("TODO: windows exit node suport")
// For now just use the net package:
return handleExitNodeDNSQueryWithNetPkg(ctx, nil, resp)
case "darwin":
// /etc/resolv.conf is a lie and only says one upstream DNS
// but for now that's probably good enough. Later we'll
@ -404,6 +405,106 @@ func (r *Resolver) HandleExitNodeDNSQuery(ctx context.Context, q []byte, from ne
}
}
// handleExitNodeDNSQueryWithNetPkg takes a DNS query message in q and
// return a reply (for the ExitDNS DoH service) using the net package's
// native APIs. This is only used on Windows for now.
//
// If resolver is nil, the net.Resolver zero value is used.
//
// response contains the pre-serialized response, which notably
// includes the original question and its header.
func handleExitNodeDNSQueryWithNetPkg(ctx context.Context, resolver *net.Resolver, resp *response) (res []byte, err error) {
if resp.Question.Class != dns.ClassINET {
return nil, errors.New("unsupported class")
}
r := resolver
if r == nil {
r = new(net.Resolver)
}
name := resp.Question.Name.String()
handleError := func(err error) (res []byte, _ error) {
if isGoNoSuchHostError(err) {
resp.Header.RCode = dns.RCodeNameError
return marshalResponse(resp)
}
// TODO: map other errors to RCodeServerFailure?
// Or I guess our caller should do that?
return nil, err
}
resp.Header.RCode = dns.RCodeSuccess // unless changed below
switch resp.Question.Type {
case dns.TypeA, dns.TypeAAAA:
network := "ip4"
if resp.Question.Type == dns.TypeAAAA {
network = "ip6"
}
ips, err := r.LookupIP(ctx, network, name)
if err != nil {
return handleError(err)
}
for _, stdIP := range ips {
if ip, ok := netaddr.FromStdIP(stdIP); ok {
resp.IPs = append(resp.IPs, ip)
}
}
case dns.TypeTXT:
strs, err := r.LookupTXT(ctx, name)
if err != nil {
return handleError(err)
}
resp.TXT = strs
case dns.TypePTR:
ipStr, ok := unARPA(name)
if !ok {
// TODO: is this RCodeFormatError?
return nil, errors.New("bogus PTR name")
}
addrs, err := r.LookupAddr(ctx, ipStr)
if err != nil {
return handleError(err)
}
if len(addrs) > 0 {
resp.Name, _ = dnsname.ToFQDN(addrs[0])
}
case dns.TypeCNAME:
cname, err := r.LookupCNAME(ctx, name)
if err != nil {
return handleError(err)
}
resp.CNAME = cname
case dns.TypeSRV:
// Thanks, Go: "To accommodate services publishing SRV
// records under non-standard names, if both service
// and proto are empty strings, LookupSRV looks up
// name directly."
_, srvs, err := r.LookupSRV(ctx, "", "", name)
if err != nil {
return handleError(err)
}
resp.SRVs = srvs
case dns.TypeNS:
nss, err := r.LookupNS(ctx, name)
if err != nil {
return handleError(err)
}
resp.NSs = nss
default:
return nil, fmt.Errorf("unsupported record type %v", resp.Question.Type)
}
return marshalResponse(resp)
}
func isGoNoSuchHostError(err error) bool {
if de, ok := err.(*net.DNSError); ok {
return de.IsNotFound
}
return false
}
type resolvConfCache struct {
mod time.Time
size int64
@ -604,10 +705,27 @@ func (r *Resolver) handleQuery(pkt packet) {
type response struct {
Header dns.Header
Question dns.Question
// Name is the response to a PTR query.
Name dnsname.FQDN
// IP is the response to an A, AAAA, or ALL query.
IP netaddr.IP
// IP and IPs are the responses to an A, AAAA, or ALL query.
// Either/both/neither can be populated.
IP netaddr.IP
IPs []netaddr.IP
// TXT is the response to a TXT query.
// Each one is its own RR with one string.
TXT []string
// CNAME is the response to a CNAME query.
CNAME string
// SRVs are the responses to a SRV query.
SRVs []*net.SRV
// NSs are the responses to an NS query.
NSs []*net.NS
}
var dnsParserPool = &sync.Pool{
@ -683,6 +801,16 @@ func marshalAAAARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error
return builder.AAAAResource(answerHeader, answer)
}
func marshalIP(name dns.Name, ip netaddr.IP, builder *dns.Builder) error {
if ip.Is4() {
return marshalARecord(name, ip, builder)
}
if ip.Is6() {
return marshalAAAARecord(name, ip, builder)
}
return nil
}
// marshalPTRRecord serializes a PTR record into an active builder.
// The caller may continue using the builder following the call.
func marshalPTRRecord(queryName dns.Name, name dnsname.FQDN, builder *dns.Builder) error {
@ -702,6 +830,83 @@ func marshalPTRRecord(queryName dns.Name, name dnsname.FQDN, builder *dns.Builde
return builder.PTRResource(answerHeader, answer)
}
func marshalTXT(queryName dns.Name, txts []string, builder *dns.Builder) error {
for _, txt := range txts {
if err := builder.TXTResource(dns.ResourceHeader{
Name: queryName,
Type: dns.TypeTXT,
Class: dns.ClassINET,
TTL: uint32(defaultTTL / time.Second),
}, dns.TXTResource{
TXT: []string{txt},
}); err != nil {
return err
}
}
return nil
}
func marshalCNAME(queryName dns.Name, cname string, builder *dns.Builder) error {
if cname == "" {
return nil
}
name, err := dns.NewName(cname)
if err != nil {
return err
}
return builder.CNAMEResource(dns.ResourceHeader{
Name: queryName,
Type: dns.TypeCNAME,
Class: dns.ClassINET,
TTL: uint32(defaultTTL / time.Second),
}, dns.CNAMEResource{
CNAME: name,
})
}
func marshalNS(queryName dns.Name, nss []*net.NS, builder *dns.Builder) error {
for _, ns := range nss {
name, err := dns.NewName(ns.Host)
if err != nil {
return err
}
err = builder.NSResource(dns.ResourceHeader{
Name: queryName,
Type: dns.TypeNS,
Class: dns.ClassINET,
TTL: uint32(defaultTTL / time.Second),
}, dns.NSResource{NS: name})
if err != nil {
return err
}
}
return nil
}
func marshalSRV(queryName dns.Name, srvs []*net.SRV, builder *dns.Builder) error {
for _, s := range srvs {
srvName, err := dns.NewName(s.Target)
if err != nil {
return err
}
err = builder.SRVResource(dns.ResourceHeader{
Name: queryName,
Type: dns.TypeSRV,
Class: dns.ClassINET,
TTL: uint32(defaultTTL / time.Second),
}, dns.SRVResource{
Target: srvName,
Priority: s.Priority,
Port: s.Port,
Weight: s.Weight,
})
if err != nil {
return err
}
}
return nil
}
// marshalResponse serializes the DNS response into a new buffer.
func marshalResponse(resp *response) ([]byte, error) {
resp.Header.Response = true
@ -712,6 +917,14 @@ func marshalResponse(resp *response) ([]byte, error) {
builder := dns.NewBuilder(nil, resp.Header)
// TODO(bradfitz): I'm not sure why this wasn't enabled
// before, but for now (2021-12-09) enable it at least when
// there's more than 1 record (which was never the case
// before), where it really helps.
if len(resp.IPs) > 1 {
builder.EnableCompression()
}
isSuccess := resp.Header.RCode == dns.RCodeSuccess
if resp.Question.Type != 0 || isSuccess {
@ -738,13 +951,24 @@ func marshalResponse(resp *response) ([]byte, error) {
switch resp.Question.Type {
case dns.TypeA, dns.TypeAAAA, dns.TypeALL:
if resp.IP.Is4() {
err = marshalARecord(resp.Question.Name, resp.IP, &builder)
} else if resp.IP.Is6() {
err = marshalAAAARecord(resp.Question.Name, resp.IP, &builder)
if err := marshalIP(resp.Question.Name, resp.IP, &builder); err != nil {
return nil, err
}
for _, ip := range resp.IPs {
if err := marshalIP(resp.Question.Name, ip, &builder); err != nil {
return nil, err
}
}
case dns.TypePTR:
err = marshalPTRRecord(resp.Question.Name, resp.Name, &builder)
case dns.TypeTXT:
err = marshalTXT(resp.Question.Name, resp.TXT, &builder)
case dns.TypeCNAME:
err = marshalCNAME(resp.Question.Name, resp.CNAME, &builder)
case dns.TypeSRV:
err = marshalSRV(resp.Question.Name, resp.SRVs, &builder)
case dns.TypeNS:
err = marshalNS(resp.Question.Name, resp.NSs, &builder)
}
if err != nil {
return nil, err
@ -926,6 +1150,37 @@ func (r *Resolver) respond(query []byte) ([]byte, error) {
return marshalResponse(resp)
}
// unARPA maps from "4.4.8.8.in-addr.arpa." to "8.8.4.4", etc.
func unARPA(a string) (ipStr string, ok bool) {
const suf4 = ".in-addr.arpa."
if strings.HasSuffix(a, suf4) {
s := strings.TrimSuffix(a, suf4)
// Parse and reverse octets.
ip, err := netaddr.ParseIP(s)
if err != nil || !ip.Is4() {
return "", false
}
a4 := ip.As4()
return netaddr.IPv4(a4[3], a4[2], a4[1], a4[0]).String(), true
}
const suf6 = ".ip6.arpa."
if len(a) == len("e.0.0.2.0.0.0.0.0.0.0.0.0.0.0.0.b.0.8.0.a.0.0.4.0.b.8.f.7.0.6.2.ip6.arpa.") &&
strings.HasSuffix(a, suf6) {
var hx [32]byte
var a16 [16]byte
for i := range hx {
hx[31-i] = a[i*2]
if a[i*2+1] != '.' {
return "", false
}
}
hex.Decode(a16[:], hx[:])
return netaddr.IPFrom16(a16).String(), true
}
return "", false
}
var (
metricDNSQueryLocal = clientmetric.NewCounter("dns_query_local")
metricDNSQueryErrorClosed = clientmetric.NewCounter("dns_query_local_error_closed")