Skip to content

Commit

Permalink
Add context to some DNS utils; export a couple functions
Browse files Browse the repository at this point in the history
Newly exported functions are marked as experimental since I may refactor or unexport their API again.
  • Loading branch information
mholt committed Feb 20, 2025
1 parent b24a7ba commit a7894dd
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 39 deletions.
68 changes: 42 additions & 26 deletions dnsutil.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package certmagic

import (
"context"
"errors"
"fmt"
"net"
Expand All @@ -18,21 +19,24 @@ import (
//
// It has been modified.

// findZoneByFQDN determines the zone apex for the given fqdn by recursing
// up the domain labels until the nameserver returns a SOA record in the
// answer section. The logger must be non-nil.
func findZoneByFQDN(logger *zap.Logger, fqdn string, nameservers []string) (string, error) {
// FindZoneByFQDN determines the zone apex for the given fully-qualified
// domain name (FQDN) by recursing up the domain labels until the nameserver
// returns a SOA record in the answer section. The logger must be non-nil.
//
// EXPERIMENTAL: This API was previously unexported, and may be changed or
// unexported again in the future. Do not rely on it at this time.
func FindZoneByFQDN(ctx context.Context, logger *zap.Logger, fqdn string, nameservers []string) (string, error) {
if !strings.HasSuffix(fqdn, ".") {
fqdn += "."
}
soa, err := lookupSoaByFqdn(logger, fqdn, nameservers)
soa, err := lookupSoaByFqdn(ctx, logger, fqdn, nameservers)
if err != nil {
return "", err
}
return soa.zone, nil
}

func lookupSoaByFqdn(logger *zap.Logger, fqdn string, nameservers []string) (*soaCacheEntry, error) {
func lookupSoaByFqdn(ctx context.Context, logger *zap.Logger, fqdn string, nameservers []string) (*soaCacheEntry, error) {
logger = logger.Named("soa_lookup")

if !strings.HasSuffix(fqdn, ".") {
Expand All @@ -42,13 +46,17 @@ func lookupSoaByFqdn(logger *zap.Logger, fqdn string, nameservers []string) (*so
fqdnSOACacheMu.Lock()
defer fqdnSOACacheMu.Unlock()

if err := ctx.Err(); err != nil {
return nil, err
}

// prefer cached version if fresh
if ent := fqdnSOACache[fqdn]; ent != nil && !ent.isExpired() {
logger.Debug("using cached SOA result", zap.String("entry", ent.zone))
return ent, nil
}

ent, err := fetchSoaByFqdn(logger, fqdn, nameservers)
ent, err := fetchSoaByFqdn(ctx, logger, fqdn, nameservers)
if err != nil {
return nil, err
}
Expand All @@ -66,15 +74,19 @@ func lookupSoaByFqdn(logger *zap.Logger, fqdn string, nameservers []string) (*so
return ent, nil
}

func fetchSoaByFqdn(logger *zap.Logger, fqdn string, nameservers []string) (*soaCacheEntry, error) {
func fetchSoaByFqdn(ctx context.Context, logger *zap.Logger, fqdn string, nameservers []string) (*soaCacheEntry, error) {
var err error
var in *dns.Msg

labelIndexes := dns.Split(fqdn)
for _, index := range labelIndexes {
if err := ctx.Err(); err != nil {
return nil, err
}

domain := fqdn[index:]

in, err = dnsQuery(domain, dns.TypeSOA, nameservers, true)
in, err = dnsQuery(ctx, domain, dns.TypeSOA, nameservers, true)
if err != nil {
continue
}
Expand Down Expand Up @@ -122,12 +134,12 @@ func dnsMsgContainsCNAME(msg *dns.Msg) bool {
return false
}

func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
func dnsQuery(ctx context.Context, fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
m := createDNSMsg(fqdn, rtype, recursive)
var in *dns.Msg
var err error
for _, ns := range nameservers {
in, err = sendDNSQuery(m, ns)
in, err = sendDNSQuery(ctx, m, ns)
if err == nil && len(in.Answer) > 0 {
break
}
Expand All @@ -147,16 +159,16 @@ func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
return m
}

func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
func sendDNSQuery(ctx context.Context, m *dns.Msg, ns string) (*dns.Msg, error) {
udp := &dns.Client{Net: "udp", Timeout: dnsTimeout}
in, _, err := udp.Exchange(m, ns)
in, _, err := udp.ExchangeContext(ctx, m, ns)
// two kinds of errors we can handle by retrying with TCP:
// truncation and timeout; see https://github.com/caddyserver/caddy/issues/3639
truncated := in != nil && in.Truncated
timeoutErr := err != nil && strings.Contains(err.Error(), "timeout")
if truncated || timeoutErr {
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
in, _, err = tcp.Exchange(m, ns)
in, _, err = tcp.ExchangeContext(ctx, m, ns)
}
return in, err
}
Expand Down Expand Up @@ -205,7 +217,8 @@ func systemOrDefaultNameservers(path string, defaults []string) []string {
return config.Servers
}

// populateNameserverPorts ensures that all nameservers have a port number.
// populateNameserverPorts ensures that all nameservers have a port number
// If not, the the default DNS server port of 53 will be appended.
func populateNameserverPorts(servers []string) {
for i := range servers {
_, port, _ := net.SplitHostPort(servers[i])
Expand All @@ -216,7 +229,7 @@ func populateNameserverPorts(servers []string) {
}

// checkDNSPropagation checks if the expected record has been propagated to all authoritative nameservers.
func checkDNSPropagation(logger *zap.Logger, fqdn string, recType uint16, expectedValue string, checkAuthoritativeServers bool, resolvers []string) (bool, error) {
func checkDNSPropagation(ctx context.Context, logger *zap.Logger, fqdn string, recType uint16, expectedValue string, checkAuthoritativeServers bool, resolvers []string) (bool, error) {
logger = logger.Named("propagation")

if !strings.HasSuffix(fqdn, ".") {
Expand All @@ -227,7 +240,7 @@ func checkDNSPropagation(logger *zap.Logger, fqdn string, recType uint16, expect
// dereference (follow) a CNAME record if we are targeting a CNAME record
// itself
if recType != dns.TypeCNAME {
r, err := dnsQuery(fqdn, recType, resolvers, true)
r, err := dnsQuery(ctx, fqdn, recType, resolvers, true)
if err != nil {
return false, fmt.Errorf("CNAME dns query: %v", err)
}
Expand All @@ -237,7 +250,7 @@ func checkDNSPropagation(logger *zap.Logger, fqdn string, recType uint16, expect
}

if checkAuthoritativeServers {
authoritativeServers, err := lookupNameservers(logger, fqdn, resolvers)
authoritativeServers, err := lookupNameservers(ctx, logger, fqdn, resolvers)
if err != nil {
return false, fmt.Errorf("looking up authoritative nameservers: %v", err)
}
Expand All @@ -246,13 +259,13 @@ func checkDNSPropagation(logger *zap.Logger, fqdn string, recType uint16, expect
}
logger.Debug("checking authoritative nameservers", zap.Strings("resolvers", resolvers))

return checkAuthoritativeNss(fqdn, recType, expectedValue, resolvers)
return checkAuthoritativeNss(ctx, fqdn, recType, expectedValue, resolvers)
}

// checkAuthoritativeNss queries each of the given nameservers for the expected record.
func checkAuthoritativeNss(fqdn string, recType uint16, expectedValue string, nameservers []string) (bool, error) {
func checkAuthoritativeNss(ctx context.Context, fqdn string, recType uint16, expectedValue string, nameservers []string) (bool, error) {
for _, ns := range nameservers {
r, err := dnsQuery(fqdn, recType, []string{ns}, true)
r, err := dnsQuery(ctx, fqdn, recType, []string{ns}, true)
if err != nil {
return false, fmt.Errorf("querying authoritative nameservers: %v", err)
}
Expand Down Expand Up @@ -293,15 +306,15 @@ func checkAuthoritativeNss(fqdn string, recType uint16, expectedValue string, na
}

// lookupNameservers returns the authoritative nameservers for the given fqdn.
func lookupNameservers(logger *zap.Logger, fqdn string, resolvers []string) ([]string, error) {
func lookupNameservers(ctx context.Context, logger *zap.Logger, fqdn string, resolvers []string) ([]string, error) {
var authoritativeNss []string

zone, err := findZoneByFQDN(logger, fqdn, resolvers)
zone, err := FindZoneByFQDN(ctx, logger, fqdn, resolvers)
if err != nil {
return nil, fmt.Errorf("could not determine the zone for '%s': %w", fqdn, err)
}

r, err := dnsQuery(zone, dns.TypeNS, resolvers, true)
r, err := dnsQuery(ctx, zone, dns.TypeNS, resolvers, true)
if err != nil {
return nil, fmt.Errorf("querying NS resolver for zone '%s' recursively: %v", zone, err)
}
Expand Down Expand Up @@ -330,11 +343,14 @@ func updateDomainWithCName(r *dns.Msg, fqdn string) string {
return fqdn
}

// recursiveNameservers are used to pre-check DNS propagation. It
// RecursiveNameservers are used to pre-check DNS propagation. It
// picks user-configured nameservers (custom) OR the defaults
// obtained from resolv.conf and defaultNameservers if none is
// configured and ensures that all server addresses have a port value.
func recursiveNameservers(custom []string) []string {
//
// EXPERIMENTAL: This API was previously unexported, and may be
// be unexported again in the future. Do not rely on it at this time.
func RecursiveNameservers(custom []string) []string {
var servers []string
if len(custom) == 0 {
servers = systemOrDefaultNameservers(defaultResolvConf, defaultNameservers)
Expand Down
21 changes: 11 additions & 10 deletions dnsutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package certmagic
// It has been modified.

import (
"context"
"net"
"reflect"
"runtime"
Expand Down Expand Up @@ -34,7 +35,7 @@ func TestLookupNameserversOK(t *testing.T) {
t.Run(test.fqdn, func(t *testing.T) {
t.Parallel()

nss, err := lookupNameservers(zap.NewNop(), test.fqdn, recursiveNameservers(nil))
nss, err := lookupNameservers(context.Background(), zap.NewNop(), test.fqdn, RecursiveNameservers(nil))
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
Expand Down Expand Up @@ -68,7 +69,7 @@ func TestLookupNameserversErr(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()

_, err := lookupNameservers(zap.NewNop(), test.fqdn, nil)
_, err := lookupNameservers(context.Background(), zap.NewNop(), test.fqdn, nil)
if err == nil {
t.Errorf("expected error, got none")
}
Expand All @@ -93,28 +94,28 @@ var findXByFqdnTestCases = []struct {
fqdn: "scholar.google.com.",
zone: "google.com.",
primaryNs: "ns1.google.com.",
nameservers: recursiveNameservers(nil),
nameservers: RecursiveNameservers(nil),
},
{
desc: "domain is a non-existent subdomain",
fqdn: "foo.google.com.",
zone: "google.com.",
primaryNs: "ns1.google.com.",
nameservers: recursiveNameservers(nil),
nameservers: RecursiveNameservers(nil),
},
{
desc: "domain is a eTLD",
fqdn: "example.com.ac.",
zone: "ac.",
primaryNs: "a0.nic.ac.",
nameservers: recursiveNameservers(nil),
nameservers: RecursiveNameservers(nil),
},
{
desc: "domain is a cross-zone CNAME",
fqdn: "cross-zone-example.assets.sh.",
zone: "assets.sh.",
primaryNs: "gina.ns.cloudflare.com.",
nameservers: recursiveNameservers(nil),
nameservers: RecursiveNameservers(nil),
},
{
desc: "NXDOMAIN",
Expand Down Expand Up @@ -160,7 +161,7 @@ func TestFindZoneByFqdn(t *testing.T) {
}
clearFqdnCache()

zone, err := findZoneByFQDN(zap.NewNop(), test.fqdn, test.nameservers)
zone, err := FindZoneByFQDN(context.Background(), zap.NewNop(), test.fqdn, test.nameservers)
if test.expectedError != "" {
if err == nil {
t.Errorf("test %d: expected error, got none", i)
Expand Down Expand Up @@ -219,7 +220,7 @@ func TestRecursiveNameserversAddsPort(t *testing.T) {
}
custom := []string{"127.0.0.1", "ns1.google.com:43"}
expectations := []want{{port: "53"}, {port: "43"}}
results := recursiveNameservers(custom)
results := RecursiveNameservers(custom)

if !reflect.DeepEqual(custom, []string{"127.0.0.1", "ns1.google.com:43"}) {
t.Errorf("Expected custom nameservers to be unmodified. got %v", custom)
Expand Down Expand Up @@ -247,12 +248,12 @@ func TestRecursiveNameserversAddsPort(t *testing.T) {
}

func TestRecursiveNameserversDefaults(t *testing.T) {
results := recursiveNameservers(nil)
results := RecursiveNameservers(nil)
if len(results) < 1 {
t.Errorf("%v Expected at least 1 records as default when nil custom", results)
}

results = recursiveNameservers([]string{})
results = RecursiveNameservers([]string{})
if len(results) < 1 {
t.Errorf("%v Expected at least 1 records as default when empty custom", results)
}
Expand Down
6 changes: 3 additions & 3 deletions solvers.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ type DNSManager struct {
func (m *DNSManager) createRecord(ctx context.Context, dnsName, recordType, recordValue string) (zoneRecord, error) {
logger := m.logger()

zone, err := findZoneByFQDN(logger, dnsName, recursiveNameservers(m.Resolvers))
zone, err := FindZoneByFQDN(ctx, logger, dnsName, RecursiveNameservers(m.Resolvers))
if err != nil {
return zoneRecord{}, fmt.Errorf("could not determine zone for domain %q: %v", dnsName, err)
}
Expand Down Expand Up @@ -439,7 +439,7 @@ func (m *DNSManager) wait(ctx context.Context, zrec zoneRecord) error {

// how we'll do the checks
checkAuthoritativeServers := len(m.Resolvers) == 0
resolvers := recursiveNameservers(m.Resolvers)
resolvers := RecursiveNameservers(m.Resolvers)

recType := dns.TypeTXT
if zrec.record.Type == "CNAME" {
Expand All @@ -464,7 +464,7 @@ func (m *DNSManager) wait(ctx context.Context, zrec zoneRecord) error {
zap.Strings("resolvers", resolvers))

var ready bool
ready, err = checkDNSPropagation(logger, absName, recType, zrec.record.Value, checkAuthoritativeServers, resolvers)
ready, err = checkDNSPropagation(ctx, logger, absName, recType, zrec.record.Value, checkAuthoritativeServers, resolvers)
if err != nil {
return fmt.Errorf("checking DNS propagation of %q (relative=%s zone=%s resolvers=%v): %w", absName, zrec.record.Name, zrec.zone, resolvers, err)
}
Expand Down

0 comments on commit a7894dd

Please sign in to comment.