diff --git a/api.go b/api.go index d0acf62..0188eb9 100644 --- a/api.go +++ b/api.go @@ -46,11 +46,13 @@ func WebRegisterPost(ctx *iris.Context) { errstr := fmt.Sprintf("%v", err) regJSON = iris.Map{"username": "", "password": "", "domain": "", "error": errstr} regStatus = iris.StatusInternalServerError + log.Debugf("Error in registration, [%v]", err) } else { regJSON = iris.Map{"username": nu.Username, "password": nu.Password, "fulldomain": nu.Subdomain + "." + DNSConf.General.Domain, "subdomain": nu.Subdomain} regStatus = iris.StatusCreated + + log.Debugf("Successful registration, created user [%s]", nu.Username) } - log.Debugf("Successful registration, created user [%s]", nu.Username) ctx.JSON(regStatus, regJSON) } diff --git a/config.cfg b/config.cfg index 4e03122..01dab02 100644 --- a/config.cfg +++ b/config.cfg @@ -19,6 +19,12 @@ records = [ # debug messages from CORS etc debug = false +[database] +# Database engine to use, sqlite3 or postgres +engine = "sqlite3" +# Connection string, filename for sqlite3 and postgres://$username:$password@$host/$db_name for postgres +connection = "acme-dns.db" +# connection = "postgres://user:password@localhost/acmedns_db" [api] # domain name to listen requests for, mandatory if using tls = "letsencrypt" @@ -33,7 +39,7 @@ tls_cert_privkey = "/etc/tls/example.org/privkey.pem" tls_cert_fullchain = "/etc/tls/example.org/fullchain.pem" # CORS AllowOrigins, wildcards can be used corsorigins = [ - "web.example.org" + "*" ] [logconfig] diff --git a/db.go b/db.go index 2e18629..1d01deb 100644 --- a/db.go +++ b/db.go @@ -3,9 +3,12 @@ package main import ( "database/sql" "errors" + _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" "github.com/satori/go.uuid" "golang.org/x/crypto/bcrypt" + "regexp" + "time" ) type Database struct { @@ -18,11 +21,21 @@ var recordsTable = ` Password TEXT UNIQUE NOT NULL, Subdomain TEXT UNIQUE NOT NULL, Value TEXT, - LastActive DATETIME + LastActive INT );` -func (d *Database) Init(filename string) error { - db, err := sql.Open("sqlite3", filename) +// getSQLiteStmt replaces all PostgreSQL prepared statement placeholders (eg. $1, $2) with SQLite variant "?" +func getSQLiteStmt(s string) string { + re, err := regexp.Compile("\\$[0-9]") + if err != nil { + log.Errorf("%v", err) + return s + } + return re.ReplaceAllString(s, "?") +} + +func (d *Database) Init(engine string, connection string) error { + db, err := sql.Open(engine, connection) if err != nil { return err } @@ -40,20 +53,24 @@ func (d *Database) Register() (ACMETxt, error) { return ACMETxt{}, err } passwordHash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10) + timenow := time.Now().Unix() regSQL := ` INSERT INTO records( Username, Password, Subdomain, - Value, - LastActive) - values(?, ?, ?, ?, CURRENT_TIMESTAMP)` + Value, + LastActive) + values($1, $2, $3, '', $4)` + if DNSConf.Database.Engine == "sqlite3" { + regSQL = getSQLiteStmt(regSQL) + } sm, err := d.DB.Prepare(regSQL) if err != nil { return a, err } defer sm.Close() - _, err = sm.Exec(a.Username, passwordHash, a.Subdomain, a.Value) + _, err = sm.Exec(a.Username.String(), passwordHash, a.Subdomain, timenow) if err != nil { return a, err } @@ -65,8 +82,12 @@ func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) { getSQL := ` SELECT Username, Password, Subdomain, Value, LastActive FROM records - WHERE Username=? LIMIT 1 + WHERE Username=$1 LIMIT 1 ` + if DNSConf.Database.Engine == "sqlite3" { + getSQL = getSQLiteStmt(getSQL) + } + sm, err := d.DB.Prepare(getSQL) if err != nil { return ACMETxt{}, err @@ -105,8 +126,12 @@ func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) { getSQL := ` SELECT Username, Password, Subdomain, Value FROM records - WHERE Subdomain=? LIMIT 1 + WHERE Subdomain=$1 LIMIT 1 ` + if DNSConf.Database.Engine == "sqlite3" { + getSQL = getSQLiteStmt(getSQL) + } + sm, err := d.DB.Prepare(getSQL) if err != nil { return a, err @@ -132,16 +157,21 @@ func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) { func (d *Database) Update(a ACMETxt) error { // Data in a is already sanitized log.Debugf("Trying to update domain [%s] with TXT data [%s]", a.Subdomain, a.Value) + timenow := time.Now().Unix() updSQL := ` - UPDATE records SET Value=? - WHERE Username=? AND Subdomain=? + UPDATE records SET Value=$1, LastActive=$2 + WHERE Username=$3 AND Subdomain=$4 ` + if DNSConf.Database.Engine == "sqlite3" { + updSQL = getSQLiteStmt(updSQL) + } + sm, err := d.DB.Prepare(updSQL) if err != nil { return err } defer sm.Close() - _, err = sm.Exec(a.Value, a.Username, a.Subdomain) + _, err = sm.Exec(a.Value, timenow, a.Username, a.Subdomain) if err != nil { return err } diff --git a/db_test.go b/db_test.go index e7662df..a4da146 100644 --- a/db_test.go +++ b/db_test.go @@ -1,11 +1,27 @@ package main import ( + "flag" "testing" ) +var ( + postgres = flag.Bool("postgres", false, "run integration tests against PostgreSQL") +) + func TestRegister(t *testing.T) { - _ = DB.Init(":memory:") + flag.Parse() + if *postgres { + DNSConf.Database.Engine = "postgres" + err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns") + if err != nil { + t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"") + return + } + } else { + DNSConf.Database.Engine = "sqlite3" + _ = DB.Init("sqlite3", ":memory:") + } defer DB.DB.Close() // Register tests @@ -16,7 +32,18 @@ func TestRegister(t *testing.T) { } func TestGetByUsername(t *testing.T) { - _ = DB.Init(":memory:") + flag.Parse() + if *postgres { + DNSConf.Database.Engine = "postgres" + err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns") + if err != nil { + t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"") + return + } + } else { + DNSConf.Database.Engine = "sqlite3" + _ = DB.Init("sqlite3", ":memory:") + } defer DB.DB.Close() // Create reg to refer to @@ -45,7 +72,18 @@ func TestGetByUsername(t *testing.T) { } func TestGetByDomain(t *testing.T) { - _ = DB.Init(":memory:") + flag.Parse() + if *postgres { + DNSConf.Database.Engine = "postgres" + err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns") + if err != nil { + t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"") + return + } + } else { + DNSConf.Database.Engine = "sqlite3" + _ = DB.Init("sqlite3", ":memory:") + } defer DB.DB.Close() var regDomain = ACMETxt{} @@ -87,7 +125,18 @@ func TestGetByDomain(t *testing.T) { } func TestUpdate(t *testing.T) { - _ = DB.Init(":memory:") + flag.Parse() + if *postgres { + DNSConf.Database.Engine = "postgres" + err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns") + if err != nil { + t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"") + return + } + } else { + DNSConf.Database.Engine = "sqlite3" + _ = DB.Init("sqlite3", ":memory:") + } defer DB.DB.Close() // Create reg to refer to diff --git a/main.go b/main.go index b730c6e..2fb71eb 100644 --- a/main.go +++ b/main.go @@ -62,7 +62,7 @@ func main() { RR.Parse(DNSConf.General.StaticRecords) // Open database - err = DB.Init("acme-dns.db") + err = DB.Init(DNSConf.Database.Engine, DNSConf.Database.Connection) if err != nil { log.Errorf("Could not open database [%v]", err) os.Exit(1) diff --git a/types.go b/types.go index 5d98d23..105a1bc 100644 --- a/types.go +++ b/types.go @@ -3,7 +3,6 @@ package main import ( "github.com/miekg/dns" "github.com/satori/go.uuid" - "time" ) // Static records @@ -14,6 +13,7 @@ type Records struct { // Config file main struct type DNSConfig struct { General general + Database dbsettings API httpapi Logconfig logconfig } @@ -30,6 +30,11 @@ type general struct { StaticRecords []string `toml:"records"` } +type dbsettings struct { + Engine string + Connection string +} + // API config type httpapi struct { Domain string @@ -53,7 +58,7 @@ type ACMETxt struct { Username uuid.UUID Password string ACMETxtPost - LastActive time.Time + LastActive int64 } type ACMETxtPost struct {