feat: add support for basic auth (#16)

Add new CLI parameters to enable protecting the API endpoints with basic
auth authentication.
Wrap the server endpoints in a new auth middleware that protects it using
the provided basic auth credentials (if set).
Store the provided basic auth credentials as hashed values to prevent them
from being accidentally leaked.
Add unit tests to ensure the new functionality works as expected.
This commit is contained in:
Hector 2022-01-14 21:36:49 +00:00
parent 013e8f30c9
commit 6f76a03118
9 changed files with 251 additions and 10 deletions

View file

@ -39,18 +39,22 @@ See the [releases page](https://gitlab.com/hectorjsmith/fail2ban-prometheus-expo
``` ```
$ fail2ban-prometheus-exporter -h $ fail2ban-prometheus-exporter -h
-web.listen-address string -collector.textfile
address to use for metrics server (default 0.0.0.0) enable the textfile collector
-collector.textfile.directory string
directory to read text files with metrics from
-port int -port int
port to use for the metrics server (default 9191) port to use for the metrics server (default 9191)
-socket string -socket string
path to the fail2ban server socket path to the fail2ban server socket
-version -version
show version info and exit show version info and exit
-collector.textfile -web.basic-auth.password string
enable the textfile collector password to use to protect endpoints with basic auth
-collector.textfile.directory string -web.basic-auth.username string
directory to read text files with metrics from username to use to protect endpoints with basic auth
-web.listen-address string
address to use for the metrics server (default "0.0.0.0")
``` ```
**Example** **Example**

18
src/auth/hash.go Normal file
View file

@ -0,0 +1,18 @@
package auth
import (
"crypto/sha256"
"encoding/hex"
)
func Hash(data []byte) []byte {
if len(data) == 0 {
return []byte{}
}
b := sha256.Sum256(data)
return b[:]
}
func HashString(data string) string {
return hex.EncodeToString(Hash([]byte(data)))
}

26
src/auth/hash_test.go Normal file
View file

@ -0,0 +1,26 @@
package auth
import (
"reflect"
"testing"
)
func TestHashString(t *testing.T) {
tests := []struct {
name string
args string
want string
}{
{"Happy path #1", "123", "a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3"},
{"Happy path #2", "hello world", "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"},
{"Happy path #3", "H3Ll0_W0RLD", "d58a27fe9a6e73a1d8a67189fb8acace047e7a1a795276a0056d3717ad61bd0e"},
{"Blank string", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := HashString(tt.args); !reflect.DeepEqual(got, tt.want) {
t.Errorf("HashString() = %v, want %v", got, tt.want)
}
})
}
}

31
src/auth/middleware.go Normal file
View file

@ -0,0 +1,31 @@
package auth
import (
"net/http"
)
type BasicAuthProvider interface {
Enabled() bool
DoesBasicAuthMatch(username, password string) bool
}
func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAuthProvider) http.HandlerFunc {
if basicAuthProvider.Enabled() {
return func(w http.ResponseWriter, r *http.Request) {
if doesBasicAuthMatch(r, basicAuthProvider) {
handlerFunc.ServeHTTP(w, r)
} else {
w.WriteHeader(http.StatusUnauthorized)
}
}
}
return handlerFunc
}
func doesBasicAuthMatch(r *http.Request, basicAuthProvider BasicAuthProvider) bool {
rawUsername, rawPassword, ok := r.BasicAuth()
if ok {
return basicAuthProvider.DoesBasicAuthMatch(rawUsername, rawPassword)
}
return false
}

View file

@ -0,0 +1,58 @@
package auth
import (
"net/http"
"net/http/httptest"
"testing"
)
type testAuthProvider struct {
enabled bool
match bool
}
func (p testAuthProvider) Enabled() bool {
return p.enabled
}
func (p testAuthProvider) DoesBasicAuthMatch(username, password string) bool {
return p.match
}
func newTestRequest() *http.Request {
return httptest.NewRequest(http.MethodGet, "http://example.com", nil)
}
func executeBasicAuthMiddlewareTest(t *testing.T, authEnabled bool, authMatches bool, expectedCode int, expectedCallCount int) {
callCount := 0
testHandler := func(w http.ResponseWriter, r *http.Request) {
callCount++
}
handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: authEnabled, match: authMatches})
recorder := httptest.NewRecorder()
request := newTestRequest()
if authEnabled {
request.SetBasicAuth("test", "test")
}
handler.ServeHTTP(recorder, request)
if recorder.Code != expectedCode {
t.Errorf("statusCode = %v, want %v", recorder.Code, expectedCode)
}
if callCount != expectedCallCount {
t.Errorf("callCount = %v, want %v", callCount, expectedCallCount)
}
}
func Test_GIVEN_DisabledBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) {
executeBasicAuthMiddlewareTest(t, false, false, http.StatusOK, 1)
}
func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithCorrectCredentials_THEN_RequestProcessed(t *testing.T) {
executeBasicAuthMiddlewareTest(t, true, true, http.StatusOK, 1)
}
func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithIncorrectCredentials_THEN_RequestRejected(t *testing.T) {
executeBasicAuthMiddlewareTest(t, true, false, http.StatusUnauthorized, 0)
}

25
src/cfg/basicAuth.go Normal file
View file

@ -0,0 +1,25 @@
package cfg
import "fail2ban-prometheus-exporter/auth"
type hashedBasicAuth struct {
username string
password string
}
func newHashedBasicAuth(rawUsername, rawPassword string) *hashedBasicAuth {
return &hashedBasicAuth{
username: auth.HashString(rawUsername),
password: auth.HashString(rawPassword),
}
}
func (p *hashedBasicAuth) Enabled() bool {
return len(p.username) > 0 && len(p.password) > 0
}
func (p *hashedBasicAuth) DoesBasicAuthMatch(rawUsername, rawPassword string) bool {
username := auth.HashString(rawUsername)
password := auth.HashString(rawPassword)
return username == p.username && password == p.password
}

60
src/cfg/basicAuth_test.go Normal file
View file

@ -0,0 +1,60 @@
package cfg
import "testing"
func Test_hashedBasicAuth_DoesBasicAuthMatch(t *testing.T) {
type args struct {
username string
password string
}
type fields struct {
username string
password string
}
tests := []struct {
name string
fields fields
args args
want bool
}{
{"Happy test #1", fields{username: "1234", password: "1234"}, args{username: "1234", password: "1234"}, true},
{"Happy test #2", fields{username: "test", password: "1234"}, args{username: "test", password: "1234"}, true},
{"Happy test #3", fields{username: "TEST", password: "1234"}, args{username: "TEST", password: "1234"}, true},
{"Non match #1", fields{username: "test", password: "1234"}, args{username: "1234", password: "1234"}, false},
{"Non match #2", fields{username: "1234", password: "test"}, args{username: "1234", password: "1234"}, false},
{"Non match #3", fields{username: "1234", password: "test"}, args{username: "1234", password: "TEST"}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
basicAuth := newHashedBasicAuth(tt.fields.username, tt.fields.password)
if got := basicAuth.DoesBasicAuthMatch(tt.args.username, tt.args.password); got != tt.want {
t.Errorf("DoesBasicAuthMatch() = %v, want %v", got, tt.want)
}
})
}
}
func Test_hashedBasicAuth_Enabled(t *testing.T) {
type fields struct {
username string
password string
}
tests := []struct {
name string
fields fields
want bool
}{
{"Both blank", fields{username: "", password: ""}, false},
{"Single blank #1", fields{username: "test", password: ""}, false},
{"Single blank #1", fields{username: "", password: "test"}, false},
{"Both populated", fields{username: "test", password: "test"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
basicAuth := newHashedBasicAuth(tt.fields.username, tt.fields.password)
if got := basicAuth.Enabled(); got != tt.want {
t.Errorf("Enabled() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -18,9 +18,13 @@ type AppSettings struct {
Fail2BanSocketPath string Fail2BanSocketPath string
FileCollectorPath string FileCollectorPath string
FileCollectorEnabled bool FileCollectorEnabled bool
BasicAuthProvider *hashedBasicAuth
} }
func Parse() *AppSettings { func Parse() *AppSettings {
var rawBasicAuthUsername string
var rawBasicAuthPassword string
appSettings := &AppSettings{} appSettings := &AppSettings{}
flag.BoolVar(&appSettings.VersionMode, "version", false, "show version info and exit") flag.BoolVar(&appSettings.VersionMode, "version", false, "show version info and exit")
flag.StringVar(&appSettings.MetricsAddress, "web.listen-address", "0.0.0.0", "address to use for the metrics server") flag.StringVar(&appSettings.MetricsAddress, "web.listen-address", "0.0.0.0", "address to use for the metrics server")
@ -28,12 +32,19 @@ func Parse() *AppSettings {
flag.StringVar(&appSettings.Fail2BanSocketPath, "socket", "", "path to the fail2ban server socket") flag.StringVar(&appSettings.Fail2BanSocketPath, "socket", "", "path to the fail2ban server socket")
flag.BoolVar(&appSettings.FileCollectorEnabled, "collector.textfile", false, "enable the textfile collector") flag.BoolVar(&appSettings.FileCollectorEnabled, "collector.textfile", false, "enable the textfile collector")
flag.StringVar(&appSettings.FileCollectorPath, "collector.textfile.directory", "", "directory to read text files with metrics from") flag.StringVar(&appSettings.FileCollectorPath, "collector.textfile.directory", "", "directory to read text files with metrics from")
flag.StringVar(&rawBasicAuthUsername, "web.basic-auth.username", "", "username to use to protect endpoints with basic auth")
flag.StringVar(&rawBasicAuthPassword, "web.basic-auth.password", "", "password to use to protect endpoints with basic auth")
flag.Parse() flag.Parse()
appSettings.setBasicAuthValues(rawBasicAuthUsername, rawBasicAuthPassword)
appSettings.validateFlags() appSettings.validateFlags()
return appSettings return appSettings
} }
func (settings *AppSettings) setBasicAuthValues(rawUsername, rawPassword string) {
settings.BasicAuthProvider = newHashedBasicAuth(rawUsername, rawPassword)
}
func (settings *AppSettings) validateFlags() { func (settings *AppSettings) validateFlags() {
var flagsValid = true var flagsValid = true
if !settings.VersionMode { if !settings.VersionMode {
@ -50,6 +61,10 @@ func (settings *AppSettings) validateFlags() {
fmt.Printf("file collector directory path must not be empty if collector enabled\n") fmt.Printf("file collector directory path must not be empty if collector enabled\n")
flagsValid = false flagsValid = false
} }
if (len(settings.BasicAuthProvider.username) > 0) != (len(settings.BasicAuthProvider.password) > 0) {
fmt.Printf("to enable basic auth both the username and the password must be provided")
flagsValid = false
}
} }
if !flagsValid { if !flagsValid {
flag.Usage() flag.Usage()

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"fail2ban-prometheus-exporter/auth"
"fail2ban-prometheus-exporter/cfg" "fail2ban-prometheus-exporter/cfg"
"fail2ban-prometheus-exporter/collector/f2b" "fail2ban-prometheus-exporter/collector/f2b"
"fail2ban-prometheus-exporter/collector/textfile" "fail2ban-prometheus-exporter/collector/textfile"
@ -63,10 +64,13 @@ func main() {
textFileCollector := textfile.NewCollector(appSettings) textFileCollector := textfile.NewCollector(appSettings)
prometheus.MustRegister(textFileCollector) prometheus.MustRegister(textFileCollector)
http.HandleFunc("/", rootHtmlHandler) http.HandleFunc("/", auth.BasicAuthMiddleware(rootHtmlHandler, appSettings.BasicAuthProvider))
http.HandleFunc(metricsPath, func(w http.ResponseWriter, r *http.Request) { http.HandleFunc(metricsPath, auth.BasicAuthMiddleware(
metricHandler(w, r, textFileCollector) func(w http.ResponseWriter, r *http.Request) {
}) metricHandler(w, r, textFileCollector)
},
appSettings.BasicAuthProvider,
))
log.Printf("metrics available at '%s'", metricsPath) log.Printf("metrics available at '%s'", metricsPath)
svrErr := make(chan error) svrErr := make(chan error)