From 733245fb3d42af22ed5230d1ea651e67cc79a4af Mon Sep 17 00:00:00 2001
From: Joona Hoikkala <joohoi@users.noreply.github.com>
Date: Mon, 22 Jan 2018 09:53:07 +0200
Subject: [PATCH] Support for multiple TXT records per subdomain (#29)
* Support for multiple TXT records per subdomain and database upgrade functionality
* Linter fixes
* Make sure the database upgrade routine works for PostgreSQL
* Move subdomain query outside of the upgrade transaction
---
acmetxt.go | 3 +-
db.go | 193 ++++++++++++++++++++++++++++++++++++++++++++---------
db_test.go | 62 ++++++++---------
dns.go | 8 +--
main.go | 2 +
types.go | 2 +-
6 files changed, 199 insertions(+), 71 deletions(-)
diff --git a/acmetxt.go b/acmetxt.go
index d4dbab0..8584017 100644
--- a/acmetxt.go
+++ b/acmetxt.go
@@ -12,8 +12,7 @@ type ACMETxt struct {
Username uuid.UUID
Password string
ACMETxtPost
- LastActive int64
- AllowFrom cidrslice
+ AllowFrom cidrslice
}
// ACMETxtPost holds the DNS part of the ACMETxt struct
diff --git a/db.go b/db.go
index 9b6b0a0..d245808 100644
--- a/db.go
+++ b/db.go
@@ -4,7 +4,9 @@ import (
"database/sql"
"encoding/json"
"errors"
+ "fmt"
"regexp"
+ "strconv"
"time"
_ "github.com/lib/pq"
@@ -14,16 +16,38 @@ import (
"golang.org/x/crypto/bcrypt"
)
-var recordsTable = `
+// DBVersion shows the database version this code uses. This is used for update checks.
+var DBVersion = 1
+
+var acmeTable = `
+ CREATE TABLE IF NOT EXISTS acmedns(
+ Name TEXT,
+ Value TEXT
+ );`
+
+var userTable = `
CREATE TABLE IF NOT EXISTS records(
Username TEXT UNIQUE NOT NULL PRIMARY KEY,
Password TEXT UNIQUE NOT NULL,
Subdomain TEXT UNIQUE NOT NULL,
- Value TEXT,
- LastActive INT,
AllowFrom TEXT
);`
+var txtTable = `
+ CREATE TABLE IF NOT EXISTS txt(
+ Subdomain TEXT NOT NULL,
+ Value TEXT NOT NULL DEFAULT '',
+ LastUpdate INT
+ );`
+
+var txtTablePG = `
+ CREATE TABLE IF NOT EXISTS txt(
+ rowid SERIAL,
+ Subdomain TEXT NOT NULL,
+ Value TEXT NOT NULL DEFAULT '',
+ LastUpdate INT
+ );`
+
// getSQLiteStmt replaces all PostgreSQL prepared statement placeholders (eg. $1, $2) with SQLite variant "?"
func getSQLiteStmt(s string) string {
re, _ := regexp.Compile("\\$[0-9]")
@@ -38,44 +62,151 @@ func (d *acmedb) Init(engine string, connection string) error {
return err
}
d.DB = db
- //d.DB.SetMaxOpenConns(1)
- _, err = d.DB.Exec(recordsTable)
+ // Check version first to try to catch old versions without version string
+ var versionString string
+ _ = d.DB.QueryRow("SELECT Value FROM acmedns WHERE Name='db_version'").Scan(&versionString)
+ if versionString == "" {
+ versionString = "0"
+ }
+ _, err = d.DB.Exec(acmeTable)
+ _, err = d.DB.Exec(userTable)
+ if Config.Database.Engine == "sqlite3" {
+ _, err = d.DB.Exec(txtTable)
+ } else {
+ _, err = d.DB.Exec(txtTablePG)
+ }
+ // If everything is fine, handle db upgrade tasks
+ if err == nil {
+ err = d.checkDBUpgrades(versionString)
+ }
+ if err == nil {
+ if versionString == "0" {
+ // No errors so we should now be in version 1
+ insversion := fmt.Sprintf("INSERT INTO acmedns (Name, Value) values('db_version', '%d')", DBVersion)
+ _, err = db.Exec(insversion)
+ }
+ }
+ return err
+}
+
+func (d *acmedb) checkDBUpgrades(versionString string) error {
+ var err error
+ version, err := strconv.Atoi(versionString)
if err != nil {
return err
}
+ if version != DBVersion {
+ return d.handleDBUpgrades(version)
+ }
+ return nil
+
+}
+
+func (d *acmedb) handleDBUpgrades(version int) error {
+ if version == 0 {
+ return d.handleDBUpgradeTo1()
+ }
return nil
}
+func (d *acmedb) handleDBUpgradeTo1() error {
+ var err error
+ var subdomains []string
+ rows, err := d.DB.Query("SELECT Subdomain FROM records")
+ if err != nil {
+ log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade")
+ return err
+ }
+ defer rows.Close()
+ for rows.Next() {
+ var subdomain string
+ err = rows.Scan(&subdomain)
+ if err != nil {
+ log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade while reading values")
+ return err
+ }
+ subdomains = append(subdomains, subdomain)
+ }
+ err = rows.Err()
+ if err != nil {
+ log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade while inserting values")
+ return err
+ }
+ tx, err := d.DB.Begin()
+ // Rollback if errored, commit if not
+ defer func() {
+ if err != nil {
+ tx.Rollback()
+ return
+ }
+ tx.Commit()
+ }()
+ _, _ = tx.Exec("DELETE FROM txt")
+ for _, subdomain := range subdomains {
+ if subdomain != "" {
+ // Insert two rows for each subdomain to txt table
+ err = d.NewTXTValuesInTransaction(tx, subdomain)
+ if err != nil {
+ log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade while inserting values")
+ return err
+ }
+ }
+ }
+ // SQLite doesn't support dropping columns
+ if Config.Database.Engine != "sqlite3" {
+ _, _ = tx.Exec("ALTER TABLE records DROP COLUMN IF EXISTS Value")
+ _, _ = tx.Exec("ALTER TABLE records DROP COLUMN IF EXISTS LastActive")
+ }
+ _, err = tx.Exec("UPDATE acmedns SET Value='1' WHERE Name='db_version'")
+ return err
+}
+
+// Create two rows for subdomain to the txt table
+func (d *acmedb) NewTXTValuesInTransaction(tx *sql.Tx, subdomain string) error {
+ var err error
+ instr := fmt.Sprintf("INSERT INTO txt (Subdomain, LastUpdate) values('%s', 0)", subdomain)
+ _, err = tx.Exec(instr)
+ _, err = tx.Exec(instr)
+ return err
+}
+
func (d *acmedb) Register(afrom cidrslice) (ACMETxt, error) {
d.Lock()
defer d.Unlock()
+ var err error
+ tx, err := d.DB.Begin()
+ // Rollback if errored, commit if not
+ defer func() {
+ if err != nil {
+ tx.Rollback()
+ return
+ }
+ tx.Commit()
+ }()
a := newACMETxt()
a.AllowFrom = cidrslice(afrom.ValidEntries())
passwordHash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10)
- timenow := time.Now().Unix()
regSQL := `
INSERT INTO records(
Username,
Password,
Subdomain,
- Value,
- LastActive,
AllowFrom)
- values($1, $2, $3, '', $4, $5)`
+ values($1, $2, $3, $4)`
if Config.Database.Engine == "sqlite3" {
regSQL = getSQLiteStmt(regSQL)
}
- sm, err := d.DB.Prepare(regSQL)
+ sm, err := tx.Prepare(regSQL)
if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Error("Database error in prepare")
return a, errors.New("SQL error")
}
defer sm.Close()
- _, err = sm.Exec(a.Username.String(), passwordHash, a.Subdomain, timenow, a.AllowFrom.JSON())
- if err != nil {
- return a, err
+ _, err = sm.Exec(a.Username.String(), passwordHash, a.Subdomain, a.AllowFrom.JSON())
+ if err == nil {
+ err = d.NewTXTValuesInTransaction(tx, a.Subdomain)
}
- return a, nil
+ return a, err
}
func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
@@ -83,7 +214,7 @@ func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
defer d.Unlock()
var results []ACMETxt
getSQL := `
- SELECT Username, Password, Subdomain, Value, LastActive, AllowFrom
+ SELECT Username, Password, Subdomain, AllowFrom
FROM records
WHERE Username=$1 LIMIT 1
`
@@ -116,15 +247,13 @@ func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
return ACMETxt{}, errors.New("no user")
}
-func (d *acmedb) GetByDomain(domain string) ([]ACMETxt, error) {
+func (d *acmedb) GetTXTForDomain(domain string) ([]string, error) {
d.Lock()
defer d.Unlock()
domain = sanitizeString(domain)
- var a []ACMETxt
+ var txts []string
getSQL := `
- SELECT Username, Password, Subdomain, Value, LastActive, AllowFrom
- FROM records
- WHERE Subdomain=$1 LIMIT 1
+ SELECT Value FROM txt WHERE Subdomain=$1 LIMIT 2
`
if Config.Database.Engine == "sqlite3" {
getSQL = getSQLiteStmt(getSQL)
@@ -132,33 +261,37 @@ func (d *acmedb) GetByDomain(domain string) ([]ACMETxt, error) {
sm, err := d.DB.Prepare(getSQL)
if err != nil {
- return a, err
+ return txts, err
}
defer sm.Close()
rows, err := sm.Query(domain)
if err != nil {
- return a, err
+ return txts, err
}
defer rows.Close()
for rows.Next() {
- txt, err := getModelFromRow(rows)
+ var rtxt string
+ err = rows.Scan(&rtxt)
if err != nil {
- return a, err
+ return txts, err
}
- a = append(a, txt)
+ txts = append(txts, rtxt)
}
- return a, nil
+ return txts, nil
}
func (d *acmedb) Update(a ACMETxt) error {
d.Lock()
defer d.Unlock()
+ var err error
// Data in a is already sanitized
timenow := time.Now().Unix()
+
updSQL := `
- UPDATE records SET Value=$1, LastActive=$2
- WHERE Username=$3 AND Subdomain=$4
+ UPDATE txt SET Value=$1, LastUpdate=$2
+ WHERE rowid=(
+ SELECT rowid FROM txt WHERE Subdomain=$3 ORDER BY LastUpdate LIMIT 1)
`
if Config.Database.Engine == "sqlite3" {
updSQL = getSQLiteStmt(updSQL)
@@ -169,7 +302,7 @@ func (d *acmedb) Update(a ACMETxt) error {
return err
}
defer sm.Close()
- _, err = sm.Exec(a.Value, timenow, a.Username, a.Subdomain)
+ _, err = sm.Exec(a.Value, timenow, a.Subdomain)
if err != nil {
return err
}
@@ -183,8 +316,6 @@ func getModelFromRow(r *sql.Rows) (ACMETxt, error) {
&txt.Username,
&txt.Password,
&txt.Subdomain,
- &txt.Value,
- &txt.LastActive,
&afrom)
if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Error("Row scan error")
diff --git a/db_test.go b/db_test.go
index 4580a34..3cb265c 100644
--- a/db_test.go
+++ b/db_test.go
@@ -118,7 +118,7 @@ func TestPrepareErrors(t *testing.T) {
t.Errorf("Expected error, but didn't get one")
}
- _, err = DB.GetByDomain(reg.Subdomain)
+ _, err = DB.GetTXTForDomain(reg.Subdomain)
if err == nil {
t.Errorf("Expected error, but didn't get one")
}
@@ -151,7 +151,7 @@ func TestQueryExecErrors(t *testing.T) {
t.Errorf("Expected error from exec, but got none")
}
- _, err = DB.GetByDomain(reg.Subdomain)
+ _, err = DB.GetTXTForDomain(reg.Subdomain)
if err == nil {
t.Errorf("Expected error from exec in GetByDomain, but got none")
}
@@ -195,11 +195,6 @@ func TestQueryScanErrors(t *testing.T) {
if err == nil {
t.Errorf("Expected error from scan in, but got none")
}
-
- _, err = DB.GetByDomain(reg.Subdomain)
- if err == nil {
- t.Errorf("Expected error from scan in GetByDomain, but got none")
- }
}
func TestBadDBValues(t *testing.T) {
@@ -226,46 +221,55 @@ func TestBadDBValues(t *testing.T) {
t.Errorf("Expected error from scan in, but got none")
}
- _, err = DB.GetByDomain(reg.Subdomain)
+ _, err = DB.GetTXTForDomain(reg.Subdomain)
if err == nil {
t.Errorf("Expected error from scan in GetByDomain, but got none")
}
}
-func TestGetByDomain(t *testing.T) {
- var regDomain = ACMETxt{}
-
+func TestGetTXTForDomain(t *testing.T) {
// Create reg to refer to
reg, err := DB.Register(cidrslice{})
if err != nil {
t.Errorf("Registration failed, got error [%v]", err)
}
- regDomainSlice, err := DB.GetByDomain(reg.Subdomain)
+ txtval1 := "___validation_token_received_from_the_ca___"
+ txtval2 := "___validation_token_received_YEAH_the_ca___"
+
+ reg.Value = txtval1
+ _ = DB.Update(reg)
+
+ reg.Value = txtval2
+ _ = DB.Update(reg)
+
+ regDomainSlice, err := DB.GetTXTForDomain(reg.Subdomain)
if err != nil {
t.Errorf("Could not get test user, got error [%v]", err)
}
if len(regDomainSlice) == 0 {
- t.Errorf("No rows returned for GetByDomain [%s]", reg.Subdomain)
- } else {
- regDomain = regDomainSlice[0]
+ t.Errorf("No rows returned for GetTXTForDomain [%s]", reg.Subdomain)
}
- if reg.Username != regDomain.Username {
- t.Errorf("GetByUsername username [%q] did not match the original [%q]", regDomain.Username, reg.Username)
+ var val1found = false
+ var val2found = false
+ for _, v := range regDomainSlice {
+ if v == txtval1 {
+ val1found = true
+ }
+ if v == txtval2 {
+ val2found = true
+ }
}
-
- if reg.Subdomain != regDomain.Subdomain {
- t.Errorf("GetByUsername subdomain [%q] did not match the original [%q]", regDomain.Subdomain, reg.Subdomain)
+ if !val1found {
+ t.Errorf("No TXT value found for val1")
}
-
- // regDomain password already is a bcrypt hash
- if !correctPassword(reg.Password, regDomain.Password) {
- t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regDomain.Password)
+ if !val2found {
+ t.Errorf("No TXT value found for val2")
}
// Not found
- regNotfound, _ := DB.GetByDomain("does-not-exist")
+ regNotfound, _ := DB.GetTXTForDomain("does-not-exist")
if len(regNotfound) > 0 {
t.Errorf("No records should be returned.")
}
@@ -294,12 +298,4 @@ func TestUpdate(t *testing.T) {
if err != nil {
t.Errorf("DB Update failed, got error: [%v]", err)
}
-
- updUser, err := DB.GetByUsername(regUser.Username)
- if err != nil {
- t.Errorf("GetByUsername threw error [%v]", err)
- }
- if updUser.Value != validTXT {
- t.Errorf("Update failed, fetched value [%s] does not match the update value [%s]", updUser.Value, validTXT)
- }
}
diff --git a/dns.go b/dns.go
index 122c5d7..3e687a4 100644
--- a/dns.go
+++ b/dns.go
@@ -2,8 +2,8 @@ package main
import (
"fmt"
- log "github.com/sirupsen/logrus"
"github.com/miekg/dns"
+ log "github.com/sirupsen/logrus"
"strings"
"time"
)
@@ -23,16 +23,16 @@ func answerTXT(q dns.Question) ([]dns.RR, int, error) {
var ra []dns.RR
rcode := dns.RcodeNameError
subdomain := sanitizeDomainQuestion(q.Name)
- atxt, err := DB.GetByDomain(subdomain)
+ atxt, err := DB.GetTXTForDomain(subdomain)
if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record")
return ra, dns.RcodeNameError, err
}
for _, v := range atxt {
- if len(v.Value) > 0 {
+ if len(v) > 0 {
r := new(dns.TXT)
r.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1}
- r.Txt = append(r.Txt, v.Value)
+ r.Txt = append(r.Txt, v)
ra = append(ra, r)
rcode = dns.RcodeSuccess
}
diff --git a/main.go b/main.go
index c521d56..036818b 100644
--- a/main.go
+++ b/main.go
@@ -36,6 +36,8 @@ func main() {
if err != nil {
log.Errorf("Could not open database [%v]", err)
os.Exit(1)
+ } else {
+ log.Info("Connected to database")
}
DB = newDB
defer DB.Close()
diff --git a/types.go b/types.go
index 173ab52..d6b6054 100644
--- a/types.go
+++ b/types.go
@@ -79,7 +79,7 @@ type database interface {
Init(string, string) error
Register(cidrslice) (ACMETxt, error)
GetByUsername(uuid.UUID) (ACMETxt, error)
- GetByDomain(string) ([]ACMETxt, error)
+ GetTXTForDomain(string) ([]string, error)
Update(ACMETxt) error
GetBackend() *sql.DB
SetBackend(*sql.DB)
--
GitLab