package agent

import (
	"fmt"
	"io"
	"log"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/armon/go-metrics"
	"github.com/hashicorp/consul/consul"
	"github.com/hashicorp/consul/consul/structs"
	"github.com/miekg/dns"
)

const (
	// maxServiceResponses caps the number of answer records that
	// trimAnswers keeps; trimAnswers is only invoked for non-TCP replies.
	maxServiceResponses = 3 // For UDP only
	// maxRecurseRecords caps how many recursively resolved records
	// formatNodeRecord appends after a CNAME record.
	maxRecurseRecords   = 5
)

// DNSServer is used to wrap an Agent and expose various
// service discovery endpoints using a DNS interface.
type DNSServer struct {
	// agent supplies the RPC entry point and the agent-level settings
	// (datacenter, node name, ACL token) used while answering queries.
	agent        *Agent
	// config holds the DNS-specific settings read by the handlers
	// (stale reads, TTLs, health filtering, truncation).
	config       *DNSConfig
	// dnsHandler is the mux shared by the UDP and TCP servers.
	dnsHandler   *dns.ServeMux
	// dnsServer is the UDP listener.
	dnsServer    *dns.Server
	// dnsServerTCP is the TCP listener.
	dnsServerTCP *dns.Server
	// domain is the served domain, normalized to a FQDN by NewDNSServer.
	domain       string
	// recursors are the upstream resolvers; when any are supplied,
	// NewDNSServer stores them here as validated address strings.
	recursors    []string
	// logger writes to the log output passed to NewDNSServer.
	logger       *log.Logger
}

// Shutdown stops the UDP listener first and then the TCP listener. A
// failure to stop either one is logged and does not prevent the other
// from being stopped.
func (d *DNSServer) Shutdown() {
	udpErr := d.dnsServer.Shutdown()
	if udpErr != nil {
		d.logger.Printf("[ERR] dns: error stopping udp server: %v", udpErr)
	}

	tcpErr := d.dnsServerTCP.Shutdown()
	if tcpErr != nil {
		d.logger.Printf("[ERR] dns: error stopping tcp server: %v", tcpErr)
	}
}

// NewDNSServer starts a new DNS server to provide an agent interface.
//
// It builds one UDP and one TCP dns.Server bound to the same address and
// sharing a single mux, registers the reverse-lookup ("arpa."), domain and
// (when recursors are given) catch-all (".") handlers, and then starts both
// servers in background goroutines.
//
// Parameters:
//   - agent, config: stored on the returned server for use by the handlers.
//   - logOutput: destination for the server's logger.
//   - domain: the domain to serve; it is normalized to a FQDN.
//   - bind: listen address used for both the UDP and TCP servers.
//   - recursors: upstream resolvers; each is validated with recursorAddr.
//
// Returns (nil, err) when a recursor address is invalid. When a listener
// reports an error, or when both listeners have not signalled start-up
// within one second, the constructed server is returned together with a
// non-nil error. On success the error is nil.
func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain string, bind string, recursors []string) (*DNSServer, error) {
	// Make sure domain is FQDN
	domain = dns.Fqdn(domain)

	// Construct the DNS components
	mux := dns.NewServeMux()

	// wg is released once per server via NotifyStartedFunc, so waiting on
	// it tells us that both listeners are up.
	var wg sync.WaitGroup

	// Setup the servers
	server := &dns.Server{
		Addr:              bind,
		Net:               "udp",
		Handler:           mux,
		UDPSize:           65535,
		NotifyStartedFunc: wg.Done,
	}
	serverTCP := &dns.Server{
		Addr:              bind,
		Net:               "tcp",
		Handler:           mux,
		NotifyStartedFunc: wg.Done,
	}

	// Create the server
	srv := &DNSServer{
		agent:        agent,
		config:       config,
		dnsHandler:   mux,
		dnsServer:    server,
		dnsServerTCP: serverTCP,
		domain:       domain,
		recursors:    recursors,
		logger:       log.New(logOutput, "", log.LstdFlags),
	}

	// Register mux handler, for reverse lookup
	mux.HandleFunc("arpa.", srv.handlePtr)

	// Register mux handlers
	mux.HandleFunc(domain, srv.handleQuery)
	if len(recursors) > 0 {
		validatedRecursors := make([]string, len(recursors))

		// Normalize every recursor up front; a single bad address aborts
		// construction before any listener is started.
		for idx, recursor := range recursors {
			recursor, err := recursorAddr(recursor)
			if err != nil {
				return nil, fmt.Errorf("Invalid recursor address: %v", err)
			}
			validatedRecursors[idx] = recursor
		}

		srv.recursors = validatedRecursors
		// Everything outside the served domain is forwarded upstream.
		mux.HandleFunc(".", srv.handleRecurse)
	}

	// One pending start notification per server (UDP and TCP).
	wg.Add(2)

	// Async start the DNS Servers, handle a potential error. The error
	// channels are buffered so the goroutines never block on send, even
	// if nobody is receiving any more.
	errCh := make(chan error, 1)
	go func() {
		if err := server.ListenAndServe(); err != nil {
			srv.logger.Printf("[ERR] dns: error starting udp server: %v", err)
			errCh <- fmt.Errorf("dns udp setup failed: %v", err)
		}
	}()

	errChTCP := make(chan error, 1)
	go func() {
		if err := serverTCP.ListenAndServe(); err != nil {
			srv.logger.Printf("[ERR] dns: error starting tcp server: %v", err)
			errChTCP <- fmt.Errorf("dns tcp setup failed: %v", err)
		}
	}()

	// Wait for NotifyStartedFunc callbacks indicating server has started
	startCh := make(chan struct{})
	go func() {
		wg.Wait()
		close(startCh)
	}()

	// Wait for whichever comes first: a listen error from either server,
	// both servers reporting that they started, or the one second timeout.
	select {
	case e := <-errCh:
		return srv, e
	case e := <-errChTCP:
		return srv, e
	case <-startCh:
		return srv, nil
	case <-time.After(time.Second):
		return srv, fmt.Errorf("timeout setting up DNS server")
	}
}

// recursorAddr normalizes a recursor address into "host:port" form.
//
// If the address has no port, the default DNS port 53 is appended. The
// result is then resolved with net.ResolveTCPAddr and returned in its
// canonical string form. An error is returned when the address cannot be
// split into host and port (even after the default port was added) or when
// it cannot be resolved.
func recursorAddr(recursor string) (string, error) {
	// Add the port if none. The default port is appended at most once:
	// malformed inputs such as "[a]b" keep reporting a missing port no
	// matter how many times ":53" is appended, so retrying in a loop
	// would never terminate.
	_, _, err := net.SplitHostPort(recursor)
	if ae, ok := err.(*net.AddrError); ok && ae.Err == "missing port in address" {
		recursor = fmt.Sprintf("%s:%d", recursor, 53)
		_, _, err = net.SplitHostPort(recursor)
	}
	if err != nil {
		return "", err
	}

	// Get the address
	addr, err := net.ResolveTCPAddr("tcp", recursor)
	if err != nil {
		return "", err
	}

	// Return string
	return addr.String(), nil
}

// handlePtr answers "reverse" DNS queries. It lists the catalog nodes of
// the local datacenter and replies with a PTR record for the first node
// whose reversed address equals the queried name. When no node matches
// (or the catalog RPC fails) the query is handed to handleRecurse.
func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) {
	question := req.Question[0]
	defer func(start time.Time) {
		metrics.MeasureSince([]string{"consul", "dns", "ptr_query", d.agent.config.NodeName}, start)
		d.logger.Printf("[DEBUG] dns: request for %v (%v) from client %s (%s)",
			question, time.Since(start), resp.RemoteAddr().String(),
			resp.RemoteAddr().Network())
	}(time.Now())

	// Prepare the authoritative reply skeleton
	reply := new(dns.Msg)
	reply.SetReply(req)
	reply.Authoritative = true
	reply.RecursionAvailable = len(d.recursors) > 0

	// The SOA is only attached when it was explicitly asked for
	if question.Qtype == dns.TypeSOA {
		d.addSOA(d.domain, reply)
	}

	// Normalized form of the queried name, used for comparison below
	dc := d.agent.config.Datacenter
	wanted := strings.ToLower(dns.Fqdn(question.Name))

	args := structs.DCSpecificRequest{
		Datacenter: dc,
		QueryOptions: structs.QueryOptions{
			Token:      d.agent.config.ACLToken,
			AllowStale: d.config.AllowStale,
		},
	}
	var out structs.IndexedNodes

	// TODO: Replace ListNodes with an internal RPC that can do the filter
	// server side to avoid transferring the entire node list.
	if rpcErr := d.agent.RPC("Catalog.ListNodes", &args, &out); rpcErr == nil {
		for _, candidate := range out.Nodes {
			reversed, _ := dns.ReverseAddr(candidate.Address)
			if reversed != wanted {
				continue
			}

			// First match wins
			reply.Answer = append(reply.Answer, &dns.PTR{
				Hdr: dns.RR_Header{Name: question.Name, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: 0},
				Ptr: fmt.Sprintf("%s.node.%s.%s", candidate.Node, dc, d.domain),
			})
			break
		}
	}

	// Nothing known locally: let the recursors have a go
	if len(reply.Answer) == 0 {
		d.handleRecurse(resp, req)
		return
	}

	// Send the finished reply
	writeErr := resp.WriteMsg(reply)
	if writeErr != nil {
		d.logger.Printf("[WARN] dns: failed to respond: %v", writeErr)
	}
}

// handleQuery serves queries for names inside the configured domain. It
// builds an authoritative reply, lets dispatch fill it in, and writes it
// back to the client.
func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
	question := req.Question[0]
	defer func(start time.Time) {
		metrics.MeasureSince([]string{"consul", "dns", "domain_query", d.agent.config.NodeName}, start)
		d.logger.Printf("[DEBUG] dns: request for %v (%v) from client %s (%s)",
			question, time.Since(start), resp.RemoteAddr().String(),
			resp.RemoteAddr().Network())
	}(time.Now())

	// Answer over the same transport the client used
	_, overTCP := resp.RemoteAddr().(*net.TCPAddr)
	network := "udp"
	if overTCP {
		network = "tcp"
	}

	// Prepare the authoritative reply skeleton
	reply := new(dns.Msg)
	reply.SetReply(req)
	reply.Authoritative = true
	reply.RecursionAvailable = len(d.recursors) > 0

	// The SOA is only attached when it was explicitly asked for
	if question.Qtype == dns.TypeSOA {
		d.addSOA(d.domain, reply)
	}

	// Let the name-specific handler populate the reply
	d.dispatch(network, req, reply)

	// Send the finished reply
	writeErr := resp.WriteMsg(reply)
	if writeErr != nil {
		d.logger.Printf("[WARN] dns: failed to respond: %v", writeErr)
	}
}

// addSOA appends an SOA record for the given domain to the authority
// section of msg. The serial is the current Unix time.
func (d *DNSServer) addSOA(domain string, msg *dns.Msg) {
	header := dns.RR_Header{
		Name:   domain,
		Rrtype: dns.TypeSOA,
		Class:  dns.ClassINET,
		Ttl:    0,
	}
	record := &dns.SOA{
		Hdr:     header,
		Ns:      "ns." + domain,
		Mbox:    "postmaster." + domain,
		Serial:  uint32(time.Now().Unix()),
		Refresh: 3600,
		Retry:   600,
		Expire:  86400,
		Minttl:  0,
	}
	msg.Ns = append(msg.Ns, record)
}

// dispatch is used to parse a request and invoke the correct handler.
//
// The queried name is lower-cased, stripped of the served domain and split
// into labels. The right-most label selects the lookup type ("service",
// "node" or "query"); any other right-most label is taken as a datacenter
// name, removed, and parsing restarts on the remaining labels. Names that
// cannot be parsed get an SOA record and a name-error response code.
//
// network is "tcp" or "udp" and is passed through to the lookup handlers;
// req is the incoming message and resp is the reply being built.
func (d *DNSServer) dispatch(network string, req, resp *dns.Msg) {
	// By default the query is in the default datacenter
	datacenter := d.agent.config.Datacenter

	// Get the QName without the domain suffix
	qName := strings.ToLower(dns.Fqdn(req.Question[0].Name))
	qName = strings.TrimSuffix(qName, d.domain)

	// Split into the label parts
	labels := dns.SplitDomainName(qName)

	// The last label is either "node", "service", "query", or a datacenter name
PARSE:
	n := len(labels)
	if n == 0 {
		// Ran out of labels without finding a lookup type
		goto INVALID
	}
	switch labels[n-1] {
	case "service":
		// A bare "service" label names nothing to look up
		if n == 1 {
			goto INVALID
		}

		// Support RFC 2782 style syntax
		if n == 3 && strings.HasPrefix(labels[n-2], "_") && strings.HasPrefix(labels[n-3], "_") {

			// Grab the tag since we may nuke it if it's tcp
			tag := labels[n-2][1:]

			// Treat _name._tcp.service.consul as a default, no need to filter on that tag
			if tag == "tcp" {
				tag = ""
			}

			// _name._tag.service.consul
			d.serviceLookup(network, datacenter, labels[n-3][1:], tag, req, resp)

			// Consul 0.3 and prior format for SRV queries
		} else {

			// Support "." in the label, re-join all the parts
			tag := ""
			if n >= 3 {
				tag = strings.Join(labels[:n-2], ".")
			}

			// tag[.tag].name.service.consul
			d.serviceLookup(network, datacenter, labels[n-2], tag, req, resp)
		}

	case "node":
		// A bare "node" label names nothing to look up
		if n == 1 {
			goto INVALID
		}

		// Allow a "." in the node name, just join all the parts
		node := strings.Join(labels[:n-1], ".")
		d.nodeLookup(network, datacenter, node, req, resp)

	case "query":
		// A bare "query" label names nothing to look up
		if n == 1 {
			goto INVALID
		}

		// Allow a "." in the query name, just join all the parts.
		query := strings.Join(labels[:n-1], ".")
		d.preparedQueryLookup(network, datacenter, query, req, resp)

	default:
		// Store the DC, and re-parse. Each pass drops one label, so the
		// loop ends at the latest when no labels remain.
		datacenter = labels[n-1]
		labels = labels[:n-1]
		goto PARSE
	}
	return
INVALID:
	d.logger.Printf("[WARN] dns: QName invalid: %s", qName)
	d.addSOA(d.domain, resp)
	resp.SetRcode(req, dns.RcodeNameError)
}

// translateAddr picks the address to hand out for a node. The node's
// regular address is used unless WAN translation is enabled, the node
// lives in a datacenter other than this agent's, and the node carries a
// non-empty "wan" tagged address; in that case the WAN address wins.
func (d *DNSServer) translateAddr(dc string, node *structs.Node) string {
	local := node.Address

	// Translation only applies to nodes in remote datacenters
	if !d.agent.config.TranslateWanAddrs || d.agent.config.Datacenter == dc {
		return local
	}

	if wan := node.TaggedAddresses["wan"]; wan != "" {
		return wan
	}
	return local
}

// nodeLookup resolves a node name through the catalog and adds the
// resulting record(s) to resp. Query types other than ANY, A and AAAA are
// left unanswered. RPC failures produce a server-failure code; an unknown
// node produces an SOA record and a name-error code.
func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.Msg) {
	// Anything other than ANY, A or AAAA is not our business
	qType := req.Question[0].Qtype
	switch qType {
	case dns.TypeANY, dns.TypeA, dns.TypeAAAA:
	default:
		return
	}

	// Build the catalog request
	args := structs.NodeSpecificRequest{
		Datacenter: datacenter,
		Node:       node,
		QueryOptions: structs.QueryOptions{
			Token:      d.agent.config.ACLToken,
			AllowStale: d.config.AllowStale,
		},
	}
	var out structs.IndexedNodeServices

	// Query the catalog; a stale read that is older than allowed is
	// repeated with stale reads disabled.
	for {
		if err := d.agent.RPC("Catalog.NodeServices", &args, &out); err != nil {
			d.logger.Printf("[ERR] dns: rpc error: %v", err)
			resp.SetRcode(req, dns.RcodeServerFailure)
			return
		}
		if !args.AllowStale || out.LastContact <= d.config.MaxStale {
			break
		}
		args.AllowStale = false
		d.logger.Printf("[WARN] dns: Query results too stale, re-requesting")
	}

	// Unknown node: answer with a name error
	if out.NodeServices == nil {
		d.addSOA(d.domain, resp)
		resp.SetRcode(req, dns.RcodeNameError)
		return
	}

	// Turn the node into answer records
	found := out.NodeServices.Node
	addr := d.translateAddr(datacenter, found)
	answers := d.formatNodeRecord(found, addr,
		req.Question[0].Name, qType, d.config.NodeTTL)
	if len(answers) > 0 {
		resp.Answer = append(resp.Answer, answers...)
	}
}

// formatNodeRecord builds the record(s) for an address under the name
// qName. An IPv4 address yields an A record (for ANY/A queries), any other
// IP address yields an AAAA record (for ANY/AAAA queries), and an address
// that is not an IP yields a CNAME record (for ANY/CNAME/A/AAAA queries)
// followed by up to maxRecurseRecords CNAME/A/AAAA records obtained from
// resolveCNAME. A nil slice is returned when nothing matches the query type.
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration) (records []dns.RR) {
	// All records share the same header apart from the record type
	header := func(rrtype uint16) dns.RR_Header {
		return dns.RR_Header{
			Name:   qName,
			Rrtype: rrtype,
			Class:  dns.ClassINET,
			Ttl:    uint32(ttl / time.Second),
		}
	}

	ip := net.ParseIP(addr)

	// Not an IP address: treat it as a host name and emit a CNAME
	if ip == nil {
		switch qType {
		case dns.TypeANY, dns.TypeCNAME, dns.TypeA, dns.TypeAAAA:
		default:
			return nil
		}

		cname := &dns.CNAME{
			Hdr:    header(dns.TypeCNAME),
			Target: dns.Fqdn(addr),
		}
		records = append(records, cname)

		// Follow the CNAME through the recursors, keeping only address
		// and alias records and stopping at the configured limit.
		appended := 0
		for _, rr := range d.resolveCNAME(cname.Target) {
			switch rr.Header().Rrtype {
			case dns.TypeCNAME, dns.TypeA, dns.TypeAAAA:
			default:
				continue
			}
			records = append(records, rr)
			appended++
			if appended == maxRecurseRecords {
				break
			}
		}
		return records
	}

	// IPv4 address
	if ip.To4() != nil {
		if qType == dns.TypeANY || qType == dns.TypeA {
			return []dns.RR{&dns.A{Hdr: header(dns.TypeA), A: ip}}
		}
		return nil
	}

	// Any other IP address is served as AAAA
	if qType == dns.TypeANY || qType == dns.TypeAAAA {
		return []dns.RR{&dns.AAAA{Hdr: header(dns.TypeAAAA), AAAA: ip}}
	}
	return nil
}

// trimAnswers shrinks the answer section of resp in two steps: it first
// keeps at most maxServiceResponses answers, then keeps dropping the last
// answer while the message is longer than 512 bytes (the RFC 1035 limit
// for UDP). It reports whether any answer was removed.
func trimAnswers(resp *dns.Msg) (trimmed bool) {
	before := len(resp.Answer)

	// Arbitrary cap on the number of answers
	if before > maxServiceResponses {
		resp.Answer = resp.Answer[:maxServiceResponses]
	}

	// Hard size limit of 512 bytes per the RFC
	kept := len(resp.Answer)
	for kept > 0 && resp.Len() > 512 {
		kept--
		resp.Answer = resp.Answer[:kept]
	}

	return kept < before
}

// serviceLookup answers a service query. It fetches the service's nodes
// from the health endpoint, filters them by health, shuffles them and
// adds address records (plus SRV records for SRV queries) to resp.
// Replies that do not go over TCP are trimmed and, if configured, flagged
// as truncated.
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, req, resp *dns.Msg) {
	// Build the health request
	args := structs.ServiceSpecificRequest{
		Datacenter:  datacenter,
		ServiceName: service,
		ServiceTag:  tag,
		TagFilter:   tag != "",
		QueryOptions: structs.QueryOptions{
			Token:      d.agent.config.ACLToken,
			AllowStale: d.config.AllowStale,
		},
	}
	var out structs.IndexedCheckServiceNodes

	// Query the servers; a stale read that is older than allowed is
	// repeated with stale reads disabled.
	for {
		if err := d.agent.RPC("Health.ServiceNodes", &args, &out); err != nil {
			d.logger.Printf("[ERR] dns: rpc error: %v", err)
			resp.SetRcode(req, dns.RcodeServerFailure)
			return
		}
		if !args.AllowStale || out.LastContact <= d.config.MaxStale {
			break
		}
		args.AllowStale = false
		d.logger.Printf("[WARN] dns: Query results too stale, re-requesting")
	}

	// Service-specific TTL, falling back to the wildcard entry
	var ttl time.Duration
	if ttls := d.config.ServiceTTL; ttls != nil {
		if specific, ok := ttls[service]; ok {
			ttl = specific
		} else {
			ttl = ttls["*"]
		}
	}

	// Drop nodes according to their health checks
	out.Nodes = out.Nodes.Filter(d.config.OnlyPassing)

	// No nodes left: answer with a name error
	if len(out.Nodes) == 0 {
		d.addSOA(d.domain, resp)
		resp.SetRcode(req, dns.RcodeNameError)
		return
	}

	// Randomize the order of the nodes
	out.Nodes.Shuffle()

	// Address records are always attempted, SRV records only on request
	wantSRV := req.Question[0].Qtype == dns.TypeSRV
	d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl)
	if wantSRV {
		d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl)
	}

	// Everything except TCP gets a limited number of answers, and the
	// truncation flag when answers were dropped and truncation is enabled
	if network != "tcp" {
		if trimAnswers(resp) && d.config.EnableTruncate {
			resp.Truncated = true
		}
	}

	// An empty, non-truncated answer still carries the SOA
	if len(resp.Answer) == 0 && !resp.Truncated {
		d.addSOA(d.domain, resp)
	}
}

// preparedQueryLookup answers a prepared-query lookup. It executes the
// query by ID or name, derives the TTL from the query result or the
// agent's service TTL settings, and adds address records (plus SRV records
// for SRV queries) to resp. An unknown query or an empty result yields an
// SOA record and a name-error code; other RPC failures yield a
// server-failure code.
func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, req, resp *dns.Msg) {
	// Build the execute request
	args := structs.PreparedQueryExecuteRequest{
		Datacenter:    datacenter,
		QueryIDOrName: query,
		QueryOptions: structs.QueryOptions{
			Token:      d.agent.config.ACLToken,
			AllowStale: d.config.AllowStale,
		},
	}

	// TODO (slackpad) - What's a safe limit we can set here? It seems like
	// with dup filtering done at this level we need to get everything to
	// match the previous behavior. We can optimize by pushing more filtering
	// into the query execution, but for now I think we need to get the full
	// response. We could also choose a large arbitrary number that will
	// likely work in practice, like 10*maxServiceResponses which should help
	// reduce bandwidth if there are thousands of nodes available.

	endpoint := d.agent.getEndpoint(preparedQueryEndpoint)
	var out structs.PreparedQueryExecuteResponse

	// Run the query; a stale read that is older than allowed is repeated
	// with stale reads disabled.
	for {
		err := d.agent.RPC(endpoint+".Execute", &args, &out)
		if err != nil {
			// The RPC layer loses the error type, so the only way to
			// recognize an unknown query is by its message. That case
			// is a name error, everything else a server failure.
			if err.Error() == consul.ErrQueryNotFound.Error() {
				d.addSOA(d.domain, resp)
				resp.SetRcode(req, dns.RcodeNameError)
				return
			}

			d.logger.Printf("[ERR] dns: rpc error: %v", err)
			resp.SetRcode(req, dns.RcodeServerFailure)
			return
		}
		if !args.AllowStale || out.LastContact <= d.config.MaxStale {
			break
		}
		args.AllowStale = false
		d.logger.Printf("[WARN] dns: Query results too stale, re-requesting")
	}

	// The query's own TTL has priority; it was vetted at creation time
	// but is checked again here. Without one, the agent's service TTL
	// settings are consulted, falling back to the wildcard entry.
	var ttl time.Duration
	switch {
	case out.DNS.TTL != "":
		parsed, parseErr := time.ParseDuration(out.DNS.TTL)
		ttl = parsed
		if parseErr != nil {
			d.logger.Printf("[WARN] dns: Failed to parse TTL '%s' for prepared query '%s', ignoring", out.DNS.TTL, query)
		}
	case d.config.ServiceTTL != nil:
		if specific, ok := d.config.ServiceTTL[out.Service]; ok {
			ttl = specific
		} else {
			ttl = d.config.ServiceTTL["*"]
		}
	}

	// No nodes: answer with a name error
	if len(out.Nodes) == 0 {
		d.addSOA(d.domain, resp)
		resp.SetRcode(req, dns.RcodeNameError)
		return
	}

	// Address records are always attempted, SRV records only on request
	wantSRV := req.Question[0].Qtype == dns.TypeSRV
	d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl)
	if wantSRV {
		d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl)
	}

	// Everything except TCP gets a limited number of answers, and the
	// truncation flag when answers were dropped and truncation is enabled
	if network != "tcp" {
		if trimAnswers(resp) && d.config.EnableTruncate {
			resp.Truncated = true
		}
	}

	// An empty, non-truncated answer still carries the SOA
	if len(resp.Answer) == 0 && !resp.Truncated {
		d.addSOA(d.domain, resp)
	}
}

// serviceNodeRecords appends one set of address records per distinct
// address to the answer section of resp. The service address is preferred
// over the (translated) node address when it is set.
func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) {
	question := req.Question[0]
	seen := make(map[string]struct{})

	for _, entry := range nodes {
		// The service address, when present, overrides the node address
		addr := entry.Service.Address
		if addr == "" {
			addr = d.translateAddr(dc, entry.Node)
		}

		// A node may expose the same service several times (for example
		// on different ports); each address is only answered once.
		if _, dup := seen[addr]; dup {
			continue
		}
		seen[addr] = struct{}{}

		rrs := d.formatNodeRecord(entry.Node, addr, question.Name, question.Qtype, ttl)
		if len(rrs) == 0 {
			continue
		}
		resp.Answer = append(resp.Answer, rrs...)
	}
}

// serviceSRVRecords appends one SRV record per distinct
// (node, service address, port) combination to the answer section of resp
// and puts the matching address records for the SRV target into the extra
// section.
func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) {
	seen := make(map[string]struct{})

	for _, entry := range nodes {
		// The same node may list the same service and port more than
		// once; only the first occurrence produces records.
		key := fmt.Sprintf("%s:%s:%d", entry.Node.Node, entry.Service.Address, entry.Service.Port)
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}

		// The SRV record points at the node's name in its datacenter
		target := fmt.Sprintf("%s.node.%s.%s", entry.Node.Node, dc, d.domain)
		srv := &dns.SRV{
			Hdr: dns.RR_Header{
				Name:   req.Question[0].Name,
				Rrtype: dns.TypeSRV,
				Class:  dns.ClassINET,
				Ttl:    uint32(ttl / time.Second),
			},
			Priority: 1,
			Weight:   1,
			Port:     uint16(entry.Service.Port),
			Target:   target,
		}
		resp.Answer = append(resp.Answer, srv)

		// The service address, when present, overrides the node address
		addr := entry.Service.Address
		if addr == "" {
			addr = d.translateAddr(dc, entry.Node)
		}

		// Address records for the target go into the extra section
		extras := d.formatNodeRecord(entry.Node, addr, srv.Target, dns.TypeANY, ttl)
		if len(extras) == 0 {
			continue
		}
		resp.Extra = append(resp.Extra, extras...)
	}
}

// handleRecurse is used to handle recursive DNS queries. The request is
// forwarded to each configured recursor in order, over the same transport
// the client used, and the first successful response is relayed back
// verbatim. If every recursor fails (or none is configured) a SERVFAIL
// reply is sent instead. Failures to write either kind of reply are logged.
func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
	q := req.Question[0]
	network := "udp"
	// The closure reads network when it runs, so the log line reflects
	// the transport chosen below.
	defer func(s time.Time) {
		d.logger.Printf("[DEBUG] dns: request for %v (%s) (%v) from client %s (%s)",
			q, network, time.Now().Sub(s), resp.RemoteAddr().String(),
			resp.RemoteAddr().Network())
	}(time.Now())

	// Switch to TCP if the client is
	if _, ok := resp.RemoteAddr().(*net.TCPAddr); ok {
		network = "tcp"
	}

	// Recursively resolve
	c := &dns.Client{Net: network}
	var r *dns.Msg
	var rtt time.Duration
	var err error
	for _, recursor := range d.recursors {
		r, rtt, err = c.Exchange(req, recursor)
		if err == nil {
			// Forward the response
			d.logger.Printf("[DEBUG] dns: recurse RTT for %v (%v)", q, rtt)
			if err := resp.WriteMsg(r); err != nil {
				d.logger.Printf("[WARN] dns: failed to respond: %v", err)
			}
			return
		}
		d.logger.Printf("[ERR] dns: recurse failed: %v", err)
	}

	// If all resolvers fail, return a SERVFAIL message
	d.logger.Printf("[ERR] dns: all resolvers failed for %v from client %s (%s)",
		q, resp.RemoteAddr().String(), resp.RemoteAddr().Network())
	m := &dns.Msg{}
	m.SetReply(req)
	m.RecursionAvailable = true
	m.SetRcode(req, dns.RcodeServerFailure)
	// Log a failed write like every other response path does, rather
	// than dropping the error.
	if err := resp.WriteMsg(m); err != nil {
		d.logger.Printf("[WARN] dns: failed to respond: %v", err)
	}
}

// resolveCNAME asks the configured recursors, in order and over UDP, for
// the A records of name and returns the answer section of the first
// successful response. It returns nil when no recursor is configured or
// when all of them fail.
func (d *DNSServer) resolveCNAME(name string) []dns.RR {
	// Without recursors there is nobody to ask
	if len(d.recursors) == 0 {
		return nil
	}

	// Build the A query for the target name
	query := new(dns.Msg)
	query.SetQuestion(name, dns.TypeA)

	// Try the recursors one after the other
	client := &dns.Client{Net: "udp"}
	for _, upstream := range d.recursors {
		reply, rtt, err := client.Exchange(query, upstream)
		if err != nil {
			d.logger.Printf("[ERR] dns: cname recurse failed for %v: %v", name, err)
			continue
		}
		d.logger.Printf("[DEBUG] dns: cname recurse RTT for %v (%v)", name, rtt)
		return reply.Answer
	}

	d.logger.Printf("[ERR] dns: all resolvers failed for %v", name)
	return nil
}
