From af542b44a9ecbddc6bfbbda6f262439af8bf365c Mon Sep 17 00:00:00 2001
From: Joona Hoikkala <joohoi@users.noreply.github.com>
Date: Fri, 18 Oct 2019 22:24:56 +0300
Subject: [PATCH] Proper EDNS0 (non)support (#188)

* Proper EDNS0 (non)support

* Add changelog entry

* Add EDNS0 tests
---
 README.md   |  3 ++
 dns.go      | 25 ++++++++++++-----
 dns_test.go | 80 +++++++++++++++++++++++++++++++++++++++--------------
 3 files changed, 81 insertions(+), 27 deletions(-)

diff --git a/README.md b/README.md
index 8d0004b..c67b82d 100644
--- a/README.md
+++ b/README.md
@@ -345,6 +345,9 @@ use for the renewal.
 
 ## Changelog
 
+- v0.8
+   - Changed
+      - Fixed: EDNS0 support
 - v0.7.2
    - Changed
       - Fixed: Regression error of not being able to answer to incoming random-case requests.
diff --git a/dns.go b/dns.go
index baf54a8..6bdfc66 100644
--- a/dns.go
+++ b/dns.go
@@ -82,8 +82,24 @@ func (d *DNSServer) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
 	m := new(dns.Msg)
 	m.SetReply(r)
 
-	if r.Opcode == dns.OpcodeQuery {
-		d.readQuery(m)
+	// handle edns0
+	opt := r.IsEdns0()
+	if opt != nil {
+		if opt.Version() != 0 {
+			// Only EDNS0 is standardized
+			m.MsgHdr.Rcode = dns.RcodeBadVers
+			m.SetEdns0(512, false)
+		} else {
+			// We can safely do this as we know that we're not setting other OPT RRs within acme-dns.
+			m.SetEdns0(512, false)
+			if r.Opcode == dns.OpcodeQuery {
+				d.readQuery(m)
+			}
+		}
+	} else {
+		if r.Opcode == dns.OpcodeQuery {
+			d.readQuery(m)
+		}
 	}
 	w.WriteMsg(m)
 }
@@ -107,7 +123,6 @@ func (d *DNSServer) readQuery(m *dns.Msg) {
 			m.Ns = append(m.Ns, d.SOA)
 		}
 	}
-
 }
 
 func (d *DNSServer) getRecord(q dns.Question) ([]dns.RR, error) {
@@ -169,10 +184,6 @@ func (d *DNSServer) answer(q dns.Question) ([]dns.RR, int, bool, error) {
 		// Make sure that we return NOERROR if there were dynamic records for the domain
 		rcode = dns.RcodeSuccess
 	}
-	// Handle EDNS (no support at the moment)
-	if q.Qtype == dns.TypeOPT {
-		return []dns.RR{}, dns.RcodeFormatError, authoritative, nil
-	}
 	log.WithFields(log.Fields{"qtype": dns.TypeToString[q.Qtype], "domain": q.Name, "rcode": dns.RcodeToString[rcode]}).Debug("Answering question for domain")
 	return r, rcode, authoritative, nil
 }
diff --git a/dns_test.go b/dns_test.go
index bd6877c..aca82bb 100644
--- a/dns_test.go
+++ b/dns_test.go
@@ -111,8 +111,48 @@ func TestResolveA(t *testing.T) {
 func TestEDNS(t *testing.T) {
 	resolv := resolver{server: "127.0.0.1:15353"}
 	answer, _ := resolv.lookup("auth.example.org", dns.TypeOPT)
-	if answer.Rcode != dns.RcodeFormatError {
-		t.Errorf("Was expecing FORMERR rcode for OPT query, but got [%s] instead.", dns.RcodeToString[answer.Rcode])
+	if answer.Rcode != dns.RcodeSuccess {
+		t.Errorf("Was expecing NOERROR rcode for OPT query, but got [%s] instead.", dns.RcodeToString[answer.Rcode])
+	}
+}
+
+func TestEDNSA(t *testing.T) {
+	msg := new(dns.Msg)
+	msg.Id = dns.Id()
+	msg.Question = make([]dns.Question, 1)
+	msg.Question[0] = dns.Question{Name: dns.Fqdn("auth.example.org"), Qtype: dns.TypeA, Qclass: dns.ClassINET}
+	// Set EDNS0 with DO=1
+	msg.SetEdns0(512, true)
+	in, err := dns.Exchange(msg, "127.0.0.1:15353")
+	if err != nil {
+		t.Errorf("Error querying the server [%v]", err)
+	}
+	if in != nil && in.Rcode != dns.RcodeSuccess {
+		t.Errorf("Received error from the server [%s]", dns.RcodeToString[in.Rcode])
+	}
+	opt := in.IsEdns0()
+	if opt == nil {
+		t.Errorf("Should have got OPT back")
+	}
+}
+
+func TestEDNSBADVERS(t *testing.T) {
+	msg := new(dns.Msg)
+	msg.Id = dns.Id()
+	msg.Question = make([]dns.Question, 1)
+	msg.Question[0] = dns.Question{Name: dns.Fqdn("auth.example.org"), Qtype: dns.TypeA, Qclass: dns.ClassINET}
+	// Set EDNS0 with version 1
+	o := new(dns.OPT)
+	o.SetVersion(1)
+	o.Hdr.Name = "."
+	o.Hdr.Rrtype = dns.TypeOPT
+	msg.Extra = append(msg.Extra, o)
+	in, err := dns.Exchange(msg, "127.0.0.1:15353")
+	if err != nil {
+		t.Errorf("Error querying the server [%v]", err)
+	}
+	if in != nil && in.Rcode != dns.RcodeBadVers {
+		t.Errorf("Received unexpected rcode from the server [%s]", dns.RcodeToString[in.Rcode])
 	}
 }
 
@@ -220,25 +260,25 @@ func TestResolveTXT(t *testing.T) {
 }
 
 func TestCaseInsensitiveResolveA(t *testing.T) {
-  resolv := resolver{server: "127.0.0.1:15353"}
-  answer, err := resolv.lookup("aUtH.eXAmpLe.org", dns.TypeA)
-  if err != nil {
-    t.Errorf("%v", err)
-  }
-
-  if len(answer.Answer) == 0 {
-    t.Error("No answer for DNS query")
-  }
+	resolv := resolver{server: "127.0.0.1:15353"}
+	answer, err := resolv.lookup("aUtH.eXAmpLe.org", dns.TypeA)
+	if err != nil {
+		t.Errorf("%v", err)
+	}
+
+	if len(answer.Answer) == 0 {
+		t.Error("No answer for DNS query")
+	}
 }
 
 func TestCaseInsensitiveResolveSOA(t *testing.T) {
-  resolv := resolver{server: "127.0.0.1:15353"}
-  answer, _ := resolv.lookup("doesnotexist.aUtH.eXAmpLe.org", dns.TypeSOA)
-  if answer.Rcode != dns.RcodeNameError {
-    t.Errorf("Was expecing NXDOMAIN rcode, but got [%s] instead.", dns.RcodeToString[answer.Rcode])
-  }
-
-  if len(answer.Ns) == 0 {
-    t.Error("No SOA answer for DNS query")
-  }
+	resolv := resolver{server: "127.0.0.1:15353"}
+	answer, _ := resolv.lookup("doesnotexist.aUtH.eXAmpLe.org", dns.TypeSOA)
+	if answer.Rcode != dns.RcodeNameError {
+		t.Errorf("Was expecing NXDOMAIN rcode, but got [%s] instead.", dns.RcodeToString[answer.Rcode])
+	}
+
+	if len(answer.Ns) == 0 {
+		t.Error("No SOA answer for DNS query")
+	}
 }
-- 
GitLab