Merge branch 'refactor/rewrite-auth-handler-code' into 'main'

refactor: rewrite auth handler code

See merge request hectorjsmith/fail2ban-prometheus-exporter!89
This commit is contained in:
Hector 2023-06-21 10:31:33 +00:00
commit 41b05f7e16
15 changed files with 233 additions and 190 deletions

29
auth/basic.go Normal file
View file

@ -0,0 +1,29 @@
package auth
import (
"fmt"
"net/http"
)
func NewBasicAuthProvider(username, password string) AuthProvider {
return &basicAuthProvider{
hashedAuth: encodeBasicAuth(username, password),
}
}
type basicAuthProvider struct {
hashedAuth string
}
func (p *basicAuthProvider) IsAllowed(request *http.Request) bool {
username, password, ok := request.BasicAuth()
if !ok {
return false
}
requestAuth := encodeBasicAuth(username, password)
return p.hashedAuth == requestAuth
}
func encodeBasicAuth(username, password string) string {
return HashString(fmt.Sprintf("%s:%s", username, password))
}

53
auth/basic_test.go Normal file
View file

@ -0,0 +1,53 @@
package auth
import (
"net/http"
"net/http/httptest"
"testing"
)
func Test_GIVEN_BasicAuthSet_WHEN_CallingIsAllowedWithCorrectCreds_THEN_TrueReturned(t *testing.T) {
// assemble
username := "u1"
password := HashString("abc")
request := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
request.SetBasicAuth(username, password)
provider := NewBasicAuthProvider(username, password)
// act
result := provider.IsAllowed(request)
// assert
if !result {
t.Errorf("expected request to be allowed, but failed")
}
}
func Test_GIVEN_BasicAuthSet_WHEN_CallingIsAllowedWithoutCreds_THEN_FalseReturned(t *testing.T) {
// assemble
request := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
provider := NewBasicAuthProvider("u1", "p1")
// act
result := provider.IsAllowed(request)
// assert
if result {
t.Errorf("expected request to be denied, but was allowed")
}
}
func Test_GIVEN_BasicAuthSet_WHEN_CallingIsAllowedWithWrongCreds_THEN_FalseReturned(t *testing.T) {
// assemble
request := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
request.SetBasicAuth("wrong", "pw")
provider := NewBasicAuthProvider("u1", "p1")
// act
result := provider.IsAllowed(request)
// assert
if result {
t.Errorf("expected request to be denied, but was allowed")
}
}

14
auth/empty.go Normal file
View file

@ -0,0 +1,14 @@
package auth
import "net/http"
func NewEmptyAuthProvider() AuthProvider {
return &emptyAuthProvider{}
}
type emptyAuthProvider struct {
}
func (p *emptyAuthProvider) IsAllowed(request *http.Request) bool {
return true
}

36
auth/empty_test.go Normal file
View file

@ -0,0 +1,36 @@
package auth
import (
"net/http"
"net/http/httptest"
"testing"
)
func Test_GIVEN_EmptyAuth_WHEN_CallingIsAllowedWithoutAuth_THEN_TrueReturned(t *testing.T) {
// assemble
request := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
provider := NewEmptyAuthProvider()
// act
response := provider.IsAllowed(request)
// assert
if !response {
t.Errorf("expected request to be allowed, but failed")
}
}
func Test_GIVEN_EmptyAuth_WHEN_CallingIsAllowedWithAuth_THEN_TrueReturned(t *testing.T) {
// assemble
request := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
request.SetBasicAuth("user", "pass")
provider := NewEmptyAuthProvider()
// act
response := provider.IsAllowed(request)
// assert
if !response {
t.Errorf("expected request to be allowed, but failed")
}
}

View file

@ -5,7 +5,7 @@ import (
"encoding/hex" "encoding/hex"
) )
func Hash(data []byte) []byte { func hash(data []byte) []byte {
if len(data) == 0 { if len(data) == 0 {
return []byte{} return []byte{}
} }
@ -14,5 +14,5 @@ func Hash(data []byte) []byte {
} }
func HashString(data string) string { func HashString(data string) string {
return hex.EncodeToString(Hash([]byte(data))) return hex.EncodeToString(hash([]byte(data)))
} }

View file

@ -1,31 +0,0 @@
package auth
import (
"net/http"
)
type BasicAuthProvider interface {
Enabled() bool
DoesBasicAuthMatch(username, password string) bool
}
func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAuthProvider) http.HandlerFunc {
if basicAuthProvider.Enabled() {
return func(w http.ResponseWriter, r *http.Request) {
if doesBasicAuthMatch(r, basicAuthProvider) {
handlerFunc.ServeHTTP(w, r)
} else {
w.WriteHeader(http.StatusUnauthorized)
}
}
}
return handlerFunc
}
func doesBasicAuthMatch(r *http.Request, basicAuthProvider BasicAuthProvider) bool {
rawUsername, rawPassword, ok := r.BasicAuth()
if ok {
return basicAuthProvider.DoesBasicAuthMatch(rawUsername, rawPassword)
}
return false
}

View file

@ -1,58 +0,0 @@
package auth
import (
"net/http"
"net/http/httptest"
"testing"
)
type testAuthProvider struct {
enabled bool
match bool
}
func (p testAuthProvider) Enabled() bool {
return p.enabled
}
func (p testAuthProvider) DoesBasicAuthMatch(username, password string) bool {
return p.match
}
func newTestRequest() *http.Request {
return httptest.NewRequest(http.MethodGet, "http://example.com", nil)
}
func executeBasicAuthMiddlewareTest(t *testing.T, authEnabled bool, authMatches bool, expectedCode int, expectedCallCount int) {
callCount := 0
testHandler := func(w http.ResponseWriter, r *http.Request) {
callCount++
}
handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: authEnabled, match: authMatches})
recorder := httptest.NewRecorder()
request := newTestRequest()
if authEnabled {
request.SetBasicAuth("test", "test")
}
handler.ServeHTTP(recorder, request)
if recorder.Code != expectedCode {
t.Errorf("statusCode = %v, want %v", recorder.Code, expectedCode)
}
if callCount != expectedCallCount {
t.Errorf("callCount = %v, want %v", callCount, expectedCallCount)
}
}
func Test_GIVEN_DisabledBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) {
executeBasicAuthMiddlewareTest(t, false, false, http.StatusOK, 1)
}
func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithCorrectCredentials_THEN_RequestProcessed(t *testing.T) {
executeBasicAuthMiddlewareTest(t, true, true, http.StatusOK, 1)
}
func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithIncorrectCredentials_THEN_RequestRejected(t *testing.T) {
executeBasicAuthMiddlewareTest(t, true, false, http.StatusUnauthorized, 0)
}

9
auth/provider.go Normal file
View file

@ -0,0 +1,9 @@
package auth
import (
"net/http"
)
type AuthProvider interface {
IsAllowed(*http.Request) bool
}

View file

@ -1,25 +0,0 @@
package cfg
import "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth"
type hashedBasicAuth struct {
username string
password string
}
func newHashedBasicAuth(rawUsername, rawPassword string) *hashedBasicAuth {
return &hashedBasicAuth{
username: auth.HashString(rawUsername),
password: auth.HashString(rawPassword),
}
}
func (p *hashedBasicAuth) Enabled() bool {
return len(p.username) > 0 && len(p.password) > 0
}
func (p *hashedBasicAuth) DoesBasicAuthMatch(rawUsername, rawPassword string) bool {
username := auth.HashString(rawUsername)
password := auth.HashString(rawPassword)
return username == p.username && password == p.password
}

View file

@ -1,60 +0,0 @@
package cfg
import "testing"
func Test_hashedBasicAuth_DoesBasicAuthMatch(t *testing.T) {
type args struct {
username string
password string
}
type fields struct {
username string
password string
}
tests := []struct {
name string
fields fields
args args
want bool
}{
{"Happy test #1", fields{username: "1234", password: "1234"}, args{username: "1234", password: "1234"}, true},
{"Happy test #2", fields{username: "test", password: "1234"}, args{username: "test", password: "1234"}, true},
{"Happy test #3", fields{username: "TEST", password: "1234"}, args{username: "TEST", password: "1234"}, true},
{"Non match #1", fields{username: "test", password: "1234"}, args{username: "1234", password: "1234"}, false},
{"Non match #2", fields{username: "1234", password: "test"}, args{username: "1234", password: "1234"}, false},
{"Non match #3", fields{username: "1234", password: "test"}, args{username: "1234", password: "TEST"}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
basicAuth := newHashedBasicAuth(tt.fields.username, tt.fields.password)
if got := basicAuth.DoesBasicAuthMatch(tt.args.username, tt.args.password); got != tt.want {
t.Errorf("DoesBasicAuthMatch() = %v, want %v", got, tt.want)
}
})
}
}
func Test_hashedBasicAuth_Enabled(t *testing.T) {
type fields struct {
username string
password string
}
tests := []struct {
name string
fields fields
want bool
}{
{"Both blank", fields{username: "", password: ""}, false},
{"Single blank #1", fields{username: "test", password: ""}, false},
{"Single blank #1", fields{username: "", password: "test"}, false},
{"Both populated", fields{username: "test", password: "test"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
basicAuth := newHashedBasicAuth(tt.fields.username, tt.fields.password)
if got := basicAuth.Enabled(); got != tt.want {
t.Errorf("Enabled() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -2,9 +2,11 @@ package cfg
import ( import (
"fmt" "fmt"
"log"
"os" "os"
"github.com/alecthomas/kong" "github.com/alecthomas/kong"
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth"
) )
var cliStruct struct { var cliStruct struct {
@ -36,11 +38,22 @@ func Parse() *AppSettings {
Fail2BanSocketPath: cliStruct.F2bSocketPath, Fail2BanSocketPath: cliStruct.F2bSocketPath,
FileCollectorPath: cliStruct.TextFileExporterPath, FileCollectorPath: cliStruct.TextFileExporterPath,
ExitOnSocketConnError: cliStruct.ExitOnSocketError, ExitOnSocketConnError: cliStruct.ExitOnSocketError,
BasicAuthProvider: newHashedBasicAuth(cliStruct.BasicAuthUser, cliStruct.BasicAuthPass), AuthProvider: createAuthProvider(),
} }
return settings return settings
} }
func createAuthProvider() auth.AuthProvider {
username := cliStruct.BasicAuthUser
password := cliStruct.BasicAuthPass
if len(username) == 0 && len(password) == 0 {
return auth.NewEmptyAuthProvider()
}
log.Print("basic auth enabled")
return auth.NewBasicAuthProvider(username, password)
}
func validateFlags(cliCtx *kong.Context) { func validateFlags(cliCtx *kong.Context) {
var flagsValid = true var flagsValid = true
var messages = []string{} var messages = []string{}

View file

@ -1,10 +1,12 @@
package cfg package cfg
import "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth"
type AppSettings struct { type AppSettings struct {
VersionMode bool VersionMode bool
MetricsAddress string MetricsAddress string
Fail2BanSocketPath string Fail2BanSocketPath string
FileCollectorPath string FileCollectorPath string
BasicAuthProvider *hashedBasicAuth AuthProvider auth.AuthProvider
ExitOnSocketConnError bool ExitOnSocketConnError bool
} }

View file

@ -2,17 +2,18 @@ package main
import ( import (
"fmt" "fmt"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth"
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/cfg"
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/collector/f2b"
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/collector/textfile"
"log" "log"
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/cfg"
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/collector/f2b"
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/collector/textfile"
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/server"
) )
const ( const (
@ -66,17 +67,14 @@ func main() {
textFileCollector := textfile.NewCollector(appSettings) textFileCollector := textfile.NewCollector(appSettings)
prometheus.MustRegister(textFileCollector) prometheus.MustRegister(textFileCollector)
http.HandleFunc("/", auth.BasicAuthMiddleware(rootHtmlHandler, appSettings.BasicAuthProvider)) http.HandleFunc("/", server.BasicAuthMiddleware(rootHtmlHandler, appSettings.AuthProvider))
http.HandleFunc(metricsPath, auth.BasicAuthMiddleware( http.HandleFunc(metricsPath, server.BasicAuthMiddleware(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
metricHandler(w, r, textFileCollector) metricHandler(w, r, textFileCollector)
}, },
appSettings.BasicAuthProvider, appSettings.AuthProvider,
)) ))
log.Printf("metrics available at '%s'", metricsPath) log.Printf("metrics available at '%s'", metricsPath)
if appSettings.BasicAuthProvider.Enabled() {
log.Printf("basic auth enabled")
}
svrErr := make(chan error) svrErr := make(chan error)
go func() { go func() {

17
server/middleware.go Normal file
View file

@ -0,0 +1,17 @@
package server
import (
"net/http"
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth"
)
func BasicAuthMiddleware(handlerFunc http.HandlerFunc, authProvider auth.AuthProvider) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if authProvider.IsAllowed(r) {
handlerFunc.ServeHTTP(w, r)
} else {
w.WriteHeader(http.StatusUnauthorized)
}
}
}

46
server/middleware_test.go Normal file
View file

@ -0,0 +1,46 @@
package server
import (
"net/http"
"net/http/httptest"
"testing"
)
type testAuthProvider struct {
match bool
}
func (p testAuthProvider) IsAllowed(request *http.Request) bool {
return p.match
}
func newTestRequest() *http.Request {
return httptest.NewRequest(http.MethodGet, "http://example.com", nil)
}
func executeBasicAuthMiddlewareTest(t *testing.T, authMatches bool, expectedCode int, expectedCallCount int) {
callCount := 0
testHandler := func(w http.ResponseWriter, r *http.Request) {
callCount++
}
handler := BasicAuthMiddleware(testHandler, testAuthProvider{match: authMatches})
recorder := httptest.NewRecorder()
request := newTestRequest()
handler.ServeHTTP(recorder, request)
if recorder.Code != expectedCode {
t.Errorf("statusCode = %v, want %v", recorder.Code, expectedCode)
}
if callCount != expectedCallCount {
t.Errorf("callCount = %v, want %v", callCount, expectedCallCount)
}
}
func Test_GIVEN_MatchingBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) {
executeBasicAuthMiddlewareTest(t, true, http.StatusOK, 1)
}
func Test_GIVEN_NonMatchingBasicAuth_WHEN_MethodCalled_THEN_RequestRejected(t *testing.T) {
executeBasicAuthMiddlewareTest(t, false, http.StatusUnauthorized, 0)
}