diff --git a/config.cfg b/config.cfg index 4e90348..5cf6580 100644 --- a/config.cfg +++ b/config.cfg @@ -1,4 +1,6 @@ [general] +# dns interface +listen = ":53" # domain name to serve th requests off of domain = "auth.example.org" # zone name server diff --git a/dns_test.go b/dns_test.go index 0d37975..93e8990 100644 --- a/dns_test.go +++ b/dns_test.go @@ -3,7 +3,6 @@ package main import ( "errors" "fmt" - log "github.com/Sirupsen/logrus" "github.com/miekg/dns" "strings" "testing" @@ -74,22 +73,8 @@ func findRecordFromMemory(rrstr string, host string, qtype uint16) error { return errors.New(errmsg) } -func startDNSServer(addr string) (*dns.Server, resolver) { - - // DNS server part - dns.HandleFunc(".", handleRequest) - server := &dns.Server{Addr: addr, Net: "udp"} - go func() { - err := server.ListenAndServe() - if err != nil { - log.Errorf("%v", err) - } - }() - return server, resolver{server: addr} -} - func TestResolveA(t *testing.T) { - setupConfig() + resolv := resolver{server: "0.0.0.0:15353"} answer, err := resolv.lookup("auth.example.org", dns.TypeA) if err != nil { t.Errorf("%v", err) @@ -107,8 +92,7 @@ func TestResolveA(t *testing.T) { } func TestResolveTXT(t *testing.T) { - setupConfig() - + resolv := resolver{server: "0.0.0.0:15353"} validTXT := "______________valid_response_______________" atxt, err := DB.Register() diff --git a/main.go b/main.go index 4a02d49..c49882c 100644 --- a/main.go +++ b/main.go @@ -5,7 +5,6 @@ import ( log "github.com/Sirupsen/logrus" "github.com/iris-contrib/middleware/cors" "github.com/kataras/iris" - "github.com/miekg/dns" "os" ) @@ -27,7 +26,7 @@ func main() { } DNSConf = configTmp - setupLogging() + setupLogging(DNSConf.Logconfig.Format, DNSConf.Logconfig.Level) // Read the default records in RR.Parse(DNSConf.General.StaticRecords) @@ -40,16 +39,8 @@ func main() { } defer DB.DB.Close() - // DNS server part - dns.HandleFunc(".", handleRequest) - server := &dns.Server{Addr: ":53", Net: "udp"} - go func() { - err = server.ListenAndServe() - if err != nil { - log.Errorf("%v", err) - os.Exit(1) - } - }() + // DNS server + startDNS(DNSConf.General.Listen) // API server and endpoints api := iris.New() diff --git a/main_test.go b/main_test.go index 795b5a0..0f98dab 100644 --- a/main_test.go +++ b/main_test.go @@ -28,7 +28,7 @@ func TestMain(m *testing.M) { _ = DB.Init("sqlite3", ":memory:") } - server, resolv = startDNSServer("0.0.0.0:15353") + server := startDNS("0.0.0.0:15353") exitval := m.Run() server.Shutdown() DB.DB.Close() diff --git a/types.go b/types.go index 91dec01..e6cfb90 100644 --- a/types.go +++ b/types.go @@ -23,6 +23,7 @@ type authMiddleware struct{} // Config file general section type general struct { + Listen string Domain string Nsname string Nsadmin string diff --git a/util.go b/util.go index 3a49e7b..282ba99 100644 --- a/util.go +++ b/util.go @@ -6,8 +6,10 @@ import ( "fmt" "github.com/BurntSushi/toml" log "github.com/Sirupsen/logrus" + "github.com/miekg/dns" "github.com/satori/go.uuid" "math/big" + "os" "regexp" "strings" ) @@ -22,27 +24,20 @@ func readConfig(fname string) (DNSConfig, error) { func sanitizeString(s string) string { // URL safe base64 alphabet without padding as defined in ACME - re, err := regexp.Compile("[^A-Za-z\\-\\_0-9]+") - if err != nil { - log.Errorf("%v", err) - return "" - } + re, _ := regexp.Compile("[^A-Za-z\\-\\_0-9]+") return re.ReplaceAllString(s, "") } -func generatePassword(length int) (string, error) { +func generatePassword(length int) string { ret := make([]byte, length) const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890-_" alphalen := big.NewInt(int64(len(alphabet))) for i := 0; i < length; i++ { - c, err := rand.Int(rand.Reader, alphalen) - if err != nil { - return "", err - } + c, _ := rand.Int(rand.Reader, alphalen) r := int(c.Int64()) ret[i] = alphabet[r] } - return string(ret), nil + return string(ret) } func sanitizeDomainQuestion(d string) string { @@ -57,17 +52,14 @@ func sanitizeDomainQuestion(d string) string { func newACMETxt() (ACMETxt, error) { var a = ACMETxt{} - password, err := generatePassword(40) - if err != nil { - return a, err - } + password := generatePassword(40) a.Username = uuid.NewV4() a.Password = password a.Subdomain = uuid.NewV4().String() return a, nil } -func setupLogging() { +func setupLogging(format string, level string) { if DNSConf.Logconfig.Format == "json" { log.SetFormatter(&log.JSONFormatter{}) } @@ -83,3 +75,17 @@ func setupLogging() { } // TODO: file logging } + +func startDNS(listen string) *dns.Server { + // DNS server part + dns.HandleFunc(".", handleRequest) + server := &dns.Server{Addr: listen, Net: "udp"} + go func() { + err := server.ListenAndServe() + if err != nil { + log.Errorf("%v", err) + os.Exit(1) + } + }() + return server +}