Merge branch '16-add-basic-auth' into 'main'
Resolve "Add basic auth" Closes #16 See merge request hectorjsmith/fail2ban-prometheus-exporter!53
This commit is contained in:
commit
d92f7f79b6
9 changed files with 251 additions and 10 deletions
16
README.md
16
README.md
|
@ -39,18 +39,22 @@ See the [releases page](https://gitlab.com/hectorjsmith/fail2ban-prometheus-expo
|
||||||
```
|
```
|
||||||
$ fail2ban-prometheus-exporter -h
|
$ fail2ban-prometheus-exporter -h
|
||||||
|
|
||||||
-web.listen-address string
|
-collector.textfile
|
||||||
address to use for metrics server (default 0.0.0.0)
|
enable the textfile collector
|
||||||
|
-collector.textfile.directory string
|
||||||
|
directory to read text files with metrics from
|
||||||
-port int
|
-port int
|
||||||
port to use for the metrics server (default 9191)
|
port to use for the metrics server (default 9191)
|
||||||
-socket string
|
-socket string
|
||||||
path to the fail2ban server socket
|
path to the fail2ban server socket
|
||||||
-version
|
-version
|
||||||
show version info and exit
|
show version info and exit
|
||||||
-collector.textfile
|
-web.basic-auth.password string
|
||||||
enable the textfile collector
|
password to use to protect endpoints with basic auth
|
||||||
-collector.textfile.directory string
|
-web.basic-auth.username string
|
||||||
directory to read text files with metrics from
|
username to use to protect endpoints with basic auth
|
||||||
|
-web.listen-address string
|
||||||
|
address to use for the metrics server (default "0.0.0.0")
|
||||||
```
|
```
|
||||||
|
|
||||||
**Example**
|
**Example**
|
||||||
|
|
18
src/auth/hash.go
Normal file
18
src/auth/hash.go
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Hash(data []byte) []byte {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return []byte{}
|
||||||
|
}
|
||||||
|
b := sha256.Sum256(data)
|
||||||
|
return b[:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func HashString(data string) string {
|
||||||
|
return hex.EncodeToString(Hash([]byte(data)))
|
||||||
|
}
|
26
src/auth/hash_test.go
Normal file
26
src/auth/hash_test.go
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHashString(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"Happy path #1", "123", "a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3"},
|
||||||
|
{"Happy path #2", "hello world", "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"},
|
||||||
|
{"Happy path #3", "H3Ll0_W0RLD", "d58a27fe9a6e73a1d8a67189fb8acace047e7a1a795276a0056d3717ad61bd0e"},
|
||||||
|
{"Blank string", "", ""},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := HashString(tt.args); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("HashString() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
31
src/auth/middleware.go
Normal file
31
src/auth/middleware.go
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BasicAuthProvider interface {
|
||||||
|
Enabled() bool
|
||||||
|
DoesBasicAuthMatch(username, password string) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAuthProvider) http.HandlerFunc {
|
||||||
|
if basicAuthProvider.Enabled() {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if doesBasicAuthMatch(r, basicAuthProvider) {
|
||||||
|
handlerFunc.ServeHTTP(w, r)
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return handlerFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func doesBasicAuthMatch(r *http.Request, basicAuthProvider BasicAuthProvider) bool {
|
||||||
|
rawUsername, rawPassword, ok := r.BasicAuth()
|
||||||
|
if ok {
|
||||||
|
return basicAuthProvider.DoesBasicAuthMatch(rawUsername, rawPassword)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
58
src/auth/middleware_test.go
Normal file
58
src/auth/middleware_test.go
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testAuthProvider struct {
|
||||||
|
enabled bool
|
||||||
|
match bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p testAuthProvider) Enabled() bool {
|
||||||
|
return p.enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p testAuthProvider) DoesBasicAuthMatch(username, password string) bool {
|
||||||
|
return p.match
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestRequest() *http.Request {
|
||||||
|
return httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func executeBasicAuthMiddlewareTest(t *testing.T, authEnabled bool, authMatches bool, expectedCode int, expectedCallCount int) {
|
||||||
|
callCount := 0
|
||||||
|
testHandler := func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
callCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: authEnabled, match: authMatches})
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
request := newTestRequest()
|
||||||
|
if authEnabled {
|
||||||
|
request.SetBasicAuth("test", "test")
|
||||||
|
}
|
||||||
|
handler.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
if recorder.Code != expectedCode {
|
||||||
|
t.Errorf("statusCode = %v, want %v", recorder.Code, expectedCode)
|
||||||
|
}
|
||||||
|
if callCount != expectedCallCount {
|
||||||
|
t.Errorf("callCount = %v, want %v", callCount, expectedCallCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_GIVEN_DisabledBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) {
|
||||||
|
executeBasicAuthMiddlewareTest(t, false, false, http.StatusOK, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithCorrectCredentials_THEN_RequestProcessed(t *testing.T) {
|
||||||
|
executeBasicAuthMiddlewareTest(t, true, true, http.StatusOK, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithIncorrectCredentials_THEN_RequestRejected(t *testing.T) {
|
||||||
|
executeBasicAuthMiddlewareTest(t, true, false, http.StatusUnauthorized, 0)
|
||||||
|
}
|
25
src/cfg/basicAuth.go
Normal file
25
src/cfg/basicAuth.go
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
package cfg
|
||||||
|
|
||||||
|
import "fail2ban-prometheus-exporter/auth"
|
||||||
|
|
||||||
|
type hashedBasicAuth struct {
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHashedBasicAuth(rawUsername, rawPassword string) *hashedBasicAuth {
|
||||||
|
return &hashedBasicAuth{
|
||||||
|
username: auth.HashString(rawUsername),
|
||||||
|
password: auth.HashString(rawPassword),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *hashedBasicAuth) Enabled() bool {
|
||||||
|
return len(p.username) > 0 && len(p.password) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *hashedBasicAuth) DoesBasicAuthMatch(rawUsername, rawPassword string) bool {
|
||||||
|
username := auth.HashString(rawUsername)
|
||||||
|
password := auth.HashString(rawPassword)
|
||||||
|
return username == p.username && password == p.password
|
||||||
|
}
|
60
src/cfg/basicAuth_test.go
Normal file
60
src/cfg/basicAuth_test.go
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
package cfg
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func Test_hashedBasicAuth_DoesBasicAuthMatch(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
}
|
||||||
|
type fields struct {
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"Happy test #1", fields{username: "1234", password: "1234"}, args{username: "1234", password: "1234"}, true},
|
||||||
|
{"Happy test #2", fields{username: "test", password: "1234"}, args{username: "test", password: "1234"}, true},
|
||||||
|
{"Happy test #3", fields{username: "TEST", password: "1234"}, args{username: "TEST", password: "1234"}, true},
|
||||||
|
{"Non match #1", fields{username: "test", password: "1234"}, args{username: "1234", password: "1234"}, false},
|
||||||
|
{"Non match #2", fields{username: "1234", password: "test"}, args{username: "1234", password: "1234"}, false},
|
||||||
|
{"Non match #3", fields{username: "1234", password: "test"}, args{username: "1234", password: "TEST"}, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
basicAuth := newHashedBasicAuth(tt.fields.username, tt.fields.password)
|
||||||
|
if got := basicAuth.DoesBasicAuthMatch(tt.args.username, tt.args.password); got != tt.want {
|
||||||
|
t.Errorf("DoesBasicAuthMatch() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_hashedBasicAuth_Enabled(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"Both blank", fields{username: "", password: ""}, false},
|
||||||
|
{"Single blank #1", fields{username: "test", password: ""}, false},
|
||||||
|
{"Single blank #1", fields{username: "", password: "test"}, false},
|
||||||
|
{"Both populated", fields{username: "test", password: "test"}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
basicAuth := newHashedBasicAuth(tt.fields.username, tt.fields.password)
|
||||||
|
if got := basicAuth.Enabled(); got != tt.want {
|
||||||
|
t.Errorf("Enabled() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -18,9 +18,13 @@ type AppSettings struct {
|
||||||
Fail2BanSocketPath string
|
Fail2BanSocketPath string
|
||||||
FileCollectorPath string
|
FileCollectorPath string
|
||||||
FileCollectorEnabled bool
|
FileCollectorEnabled bool
|
||||||
|
BasicAuthProvider *hashedBasicAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
func Parse() *AppSettings {
|
func Parse() *AppSettings {
|
||||||
|
var rawBasicAuthUsername string
|
||||||
|
var rawBasicAuthPassword string
|
||||||
|
|
||||||
appSettings := &AppSettings{}
|
appSettings := &AppSettings{}
|
||||||
flag.BoolVar(&appSettings.VersionMode, "version", false, "show version info and exit")
|
flag.BoolVar(&appSettings.VersionMode, "version", false, "show version info and exit")
|
||||||
flag.StringVar(&appSettings.MetricsAddress, "web.listen-address", "0.0.0.0", "address to use for the metrics server")
|
flag.StringVar(&appSettings.MetricsAddress, "web.listen-address", "0.0.0.0", "address to use for the metrics server")
|
||||||
|
@ -28,12 +32,19 @@ func Parse() *AppSettings {
|
||||||
flag.StringVar(&appSettings.Fail2BanSocketPath, "socket", "", "path to the fail2ban server socket")
|
flag.StringVar(&appSettings.Fail2BanSocketPath, "socket", "", "path to the fail2ban server socket")
|
||||||
flag.BoolVar(&appSettings.FileCollectorEnabled, "collector.textfile", false, "enable the textfile collector")
|
flag.BoolVar(&appSettings.FileCollectorEnabled, "collector.textfile", false, "enable the textfile collector")
|
||||||
flag.StringVar(&appSettings.FileCollectorPath, "collector.textfile.directory", "", "directory to read text files with metrics from")
|
flag.StringVar(&appSettings.FileCollectorPath, "collector.textfile.directory", "", "directory to read text files with metrics from")
|
||||||
|
flag.StringVar(&rawBasicAuthUsername, "web.basic-auth.username", "", "username to use to protect endpoints with basic auth")
|
||||||
|
flag.StringVar(&rawBasicAuthPassword, "web.basic-auth.password", "", "password to use to protect endpoints with basic auth")
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
appSettings.setBasicAuthValues(rawBasicAuthUsername, rawBasicAuthPassword)
|
||||||
appSettings.validateFlags()
|
appSettings.validateFlags()
|
||||||
return appSettings
|
return appSettings
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (settings *AppSettings) setBasicAuthValues(rawUsername, rawPassword string) {
|
||||||
|
settings.BasicAuthProvider = newHashedBasicAuth(rawUsername, rawPassword)
|
||||||
|
}
|
||||||
|
|
||||||
func (settings *AppSettings) validateFlags() {
|
func (settings *AppSettings) validateFlags() {
|
||||||
var flagsValid = true
|
var flagsValid = true
|
||||||
if !settings.VersionMode {
|
if !settings.VersionMode {
|
||||||
|
@ -50,6 +61,10 @@ func (settings *AppSettings) validateFlags() {
|
||||||
fmt.Printf("file collector directory path must not be empty if collector enabled\n")
|
fmt.Printf("file collector directory path must not be empty if collector enabled\n")
|
||||||
flagsValid = false
|
flagsValid = false
|
||||||
}
|
}
|
||||||
|
if (len(settings.BasicAuthProvider.username) > 0) != (len(settings.BasicAuthProvider.password) > 0) {
|
||||||
|
fmt.Printf("to enable basic auth both the username and the password must be provided")
|
||||||
|
flagsValid = false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if !flagsValid {
|
if !flagsValid {
|
||||||
flag.Usage()
|
flag.Usage()
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fail2ban-prometheus-exporter/auth"
|
||||||
"fail2ban-prometheus-exporter/cfg"
|
"fail2ban-prometheus-exporter/cfg"
|
||||||
"fail2ban-prometheus-exporter/collector/f2b"
|
"fail2ban-prometheus-exporter/collector/f2b"
|
||||||
"fail2ban-prometheus-exporter/collector/textfile"
|
"fail2ban-prometheus-exporter/collector/textfile"
|
||||||
|
@ -63,10 +64,13 @@ func main() {
|
||||||
textFileCollector := textfile.NewCollector(appSettings)
|
textFileCollector := textfile.NewCollector(appSettings)
|
||||||
prometheus.MustRegister(textFileCollector)
|
prometheus.MustRegister(textFileCollector)
|
||||||
|
|
||||||
http.HandleFunc("/", rootHtmlHandler)
|
http.HandleFunc("/", auth.BasicAuthMiddleware(rootHtmlHandler, appSettings.BasicAuthProvider))
|
||||||
http.HandleFunc(metricsPath, func(w http.ResponseWriter, r *http.Request) {
|
http.HandleFunc(metricsPath, auth.BasicAuthMiddleware(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
metricHandler(w, r, textFileCollector)
|
metricHandler(w, r, textFileCollector)
|
||||||
})
|
},
|
||||||
|
appSettings.BasicAuthProvider,
|
||||||
|
))
|
||||||
log.Printf("metrics available at '%s'", metricsPath)
|
log.Printf("metrics available at '%s'", metricsPath)
|
||||||
|
|
||||||
svrErr := make(chan error)
|
svrErr := make(chan error)
|
||||||
|
|
Loading…
Reference in a new issue