refactor: rewrite auth handler code (!89)
* Rewrite the code handling basic auth to make it easier to extend for other types of auth. * The behaviour of the existing code is maintained. * No changes to how basic auth is configured from a user's perspective. https://gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/-/merge_requests/89
This commit is contained in:
parent
3215fe5f4c
commit
3cff8ccd64
15 changed files with 233 additions and 190 deletions
29
auth/basic.go
Normal file
29
auth/basic.go
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewBasicAuthProvider(username, password string) AuthProvider {
|
||||||
|
return &basicAuthProvider{
|
||||||
|
hashedAuth: encodeBasicAuth(username, password),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type basicAuthProvider struct {
|
||||||
|
hashedAuth string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *basicAuthProvider) IsAllowed(request *http.Request) bool {
|
||||||
|
username, password, ok := request.BasicAuth()
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
requestAuth := encodeBasicAuth(username, password)
|
||||||
|
return p.hashedAuth == requestAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeBasicAuth(username, password string) string {
|
||||||
|
return HashString(fmt.Sprintf("%s:%s", username, password))
|
||||||
|
}
|
53
auth/basic_test.go
Normal file
53
auth/basic_test.go
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_GIVEN_BasicAuthSet_WHEN_CallingIsAllowedWithCorrectCreds_THEN_TrueReturned(t *testing.T) {
|
||||||
|
// assemble
|
||||||
|
username := "u1"
|
||||||
|
password := HashString("abc")
|
||||||
|
request := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
request.SetBasicAuth(username, password)
|
||||||
|
provider := NewBasicAuthProvider(username, password)
|
||||||
|
|
||||||
|
// act
|
||||||
|
result := provider.IsAllowed(request)
|
||||||
|
|
||||||
|
// assert
|
||||||
|
if !result {
|
||||||
|
t.Errorf("expected request to be allowed, but failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_GIVEN_BasicAuthSet_WHEN_CallingIsAllowedWithoutCreds_THEN_FalseReturned(t *testing.T) {
|
||||||
|
// assemble
|
||||||
|
request := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
provider := NewBasicAuthProvider("u1", "p1")
|
||||||
|
|
||||||
|
// act
|
||||||
|
result := provider.IsAllowed(request)
|
||||||
|
|
||||||
|
// assert
|
||||||
|
if result {
|
||||||
|
t.Errorf("expected request to be denied, but was allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_GIVEN_BasicAuthSet_WHEN_CallingIsAllowedWithWrongCreds_THEN_FalseReturned(t *testing.T) {
|
||||||
|
// assemble
|
||||||
|
request := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
request.SetBasicAuth("wrong", "pw")
|
||||||
|
provider := NewBasicAuthProvider("u1", "p1")
|
||||||
|
|
||||||
|
// act
|
||||||
|
result := provider.IsAllowed(request)
|
||||||
|
|
||||||
|
// assert
|
||||||
|
if result {
|
||||||
|
t.Errorf("expected request to be denied, but was allowed")
|
||||||
|
}
|
||||||
|
}
|
14
auth/empty.go
Normal file
14
auth/empty.go
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
func NewEmptyAuthProvider() AuthProvider {
|
||||||
|
return &emptyAuthProvider{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type emptyAuthProvider struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *emptyAuthProvider) IsAllowed(request *http.Request) bool {
|
||||||
|
return true
|
||||||
|
}
|
36
auth/empty_test.go
Normal file
36
auth/empty_test.go
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_GIVEN_EmptyAuth_WHEN_CallingIsAllowedWithoutAuth_THEN_TrueReturned(t *testing.T) {
|
||||||
|
// assemble
|
||||||
|
request := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
provider := NewEmptyAuthProvider()
|
||||||
|
|
||||||
|
// act
|
||||||
|
response := provider.IsAllowed(request)
|
||||||
|
|
||||||
|
// assert
|
||||||
|
if !response {
|
||||||
|
t.Errorf("expected request to be allowed, but failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_GIVEN_EmptyAuth_WHEN_CallingIsAllowedWithAuth_THEN_TrueReturned(t *testing.T) {
|
||||||
|
// assemble
|
||||||
|
request := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
request.SetBasicAuth("user", "pass")
|
||||||
|
provider := NewEmptyAuthProvider()
|
||||||
|
|
||||||
|
// act
|
||||||
|
response := provider.IsAllowed(request)
|
||||||
|
|
||||||
|
// assert
|
||||||
|
if !response {
|
||||||
|
t.Errorf("expected request to be allowed, but failed")
|
||||||
|
}
|
||||||
|
}
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Hash(data []byte) []byte {
|
func hash(data []byte) []byte {
|
||||||
if len(data) == 0 {
|
if len(data) == 0 {
|
||||||
return []byte{}
|
return []byte{}
|
||||||
}
|
}
|
||||||
|
@ -14,5 +14,5 @@ func Hash(data []byte) []byte {
|
||||||
}
|
}
|
||||||
|
|
||||||
func HashString(data string) string {
|
func HashString(data string) string {
|
||||||
return hex.EncodeToString(Hash([]byte(data)))
|
return hex.EncodeToString(hash([]byte(data)))
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,31 +0,0 @@
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
type BasicAuthProvider interface {
|
|
||||||
Enabled() bool
|
|
||||||
DoesBasicAuthMatch(username, password string) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func BasicAuthMiddleware(handlerFunc http.HandlerFunc, basicAuthProvider BasicAuthProvider) http.HandlerFunc {
|
|
||||||
if basicAuthProvider.Enabled() {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if doesBasicAuthMatch(r, basicAuthProvider) {
|
|
||||||
handlerFunc.ServeHTTP(w, r)
|
|
||||||
} else {
|
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return handlerFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
func doesBasicAuthMatch(r *http.Request, basicAuthProvider BasicAuthProvider) bool {
|
|
||||||
rawUsername, rawPassword, ok := r.BasicAuth()
|
|
||||||
if ok {
|
|
||||||
return basicAuthProvider.DoesBasicAuthMatch(rawUsername, rawPassword)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
|
@ -1,58 +0,0 @@
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
type testAuthProvider struct {
|
|
||||||
enabled bool
|
|
||||||
match bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p testAuthProvider) Enabled() bool {
|
|
||||||
return p.enabled
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p testAuthProvider) DoesBasicAuthMatch(username, password string) bool {
|
|
||||||
return p.match
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTestRequest() *http.Request {
|
|
||||||
return httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func executeBasicAuthMiddlewareTest(t *testing.T, authEnabled bool, authMatches bool, expectedCode int, expectedCallCount int) {
|
|
||||||
callCount := 0
|
|
||||||
testHandler := func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
callCount++
|
|
||||||
}
|
|
||||||
|
|
||||||
handler := BasicAuthMiddleware(testHandler, testAuthProvider{enabled: authEnabled, match: authMatches})
|
|
||||||
recorder := httptest.NewRecorder()
|
|
||||||
request := newTestRequest()
|
|
||||||
if authEnabled {
|
|
||||||
request.SetBasicAuth("test", "test")
|
|
||||||
}
|
|
||||||
handler.ServeHTTP(recorder, request)
|
|
||||||
|
|
||||||
if recorder.Code != expectedCode {
|
|
||||||
t.Errorf("statusCode = %v, want %v", recorder.Code, expectedCode)
|
|
||||||
}
|
|
||||||
if callCount != expectedCallCount {
|
|
||||||
t.Errorf("callCount = %v, want %v", callCount, expectedCallCount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_GIVEN_DisabledBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) {
|
|
||||||
executeBasicAuthMiddlewareTest(t, false, false, http.StatusOK, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithCorrectCredentials_THEN_RequestProcessed(t *testing.T) {
|
|
||||||
executeBasicAuthMiddlewareTest(t, true, true, http.StatusOK, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_GIVEN_EnabledBasicAuth_WHEN_MethodCalledWithIncorrectCredentials_THEN_RequestRejected(t *testing.T) {
|
|
||||||
executeBasicAuthMiddlewareTest(t, true, false, http.StatusUnauthorized, 0)
|
|
||||||
}
|
|
9
auth/provider.go
Normal file
9
auth/provider.go
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthProvider interface {
|
||||||
|
IsAllowed(*http.Request) bool
|
||||||
|
}
|
|
@ -1,25 +0,0 @@
|
||||||
package cfg
|
|
||||||
|
|
||||||
import "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth"
|
|
||||||
|
|
||||||
type hashedBasicAuth struct {
|
|
||||||
username string
|
|
||||||
password string
|
|
||||||
}
|
|
||||||
|
|
||||||
func newHashedBasicAuth(rawUsername, rawPassword string) *hashedBasicAuth {
|
|
||||||
return &hashedBasicAuth{
|
|
||||||
username: auth.HashString(rawUsername),
|
|
||||||
password: auth.HashString(rawPassword),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *hashedBasicAuth) Enabled() bool {
|
|
||||||
return len(p.username) > 0 && len(p.password) > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *hashedBasicAuth) DoesBasicAuthMatch(rawUsername, rawPassword string) bool {
|
|
||||||
username := auth.HashString(rawUsername)
|
|
||||||
password := auth.HashString(rawPassword)
|
|
||||||
return username == p.username && password == p.password
|
|
||||||
}
|
|
|
@ -1,60 +0,0 @@
|
||||||
package cfg
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func Test_hashedBasicAuth_DoesBasicAuthMatch(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
username string
|
|
||||||
password string
|
|
||||||
}
|
|
||||||
type fields struct {
|
|
||||||
username string
|
|
||||||
password string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{"Happy test #1", fields{username: "1234", password: "1234"}, args{username: "1234", password: "1234"}, true},
|
|
||||||
{"Happy test #2", fields{username: "test", password: "1234"}, args{username: "test", password: "1234"}, true},
|
|
||||||
{"Happy test #3", fields{username: "TEST", password: "1234"}, args{username: "TEST", password: "1234"}, true},
|
|
||||||
{"Non match #1", fields{username: "test", password: "1234"}, args{username: "1234", password: "1234"}, false},
|
|
||||||
{"Non match #2", fields{username: "1234", password: "test"}, args{username: "1234", password: "1234"}, false},
|
|
||||||
{"Non match #3", fields{username: "1234", password: "test"}, args{username: "1234", password: "TEST"}, false},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
basicAuth := newHashedBasicAuth(tt.fields.username, tt.fields.password)
|
|
||||||
if got := basicAuth.DoesBasicAuthMatch(tt.args.username, tt.args.password); got != tt.want {
|
|
||||||
t.Errorf("DoesBasicAuthMatch() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_hashedBasicAuth_Enabled(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
username string
|
|
||||||
password string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{"Both blank", fields{username: "", password: ""}, false},
|
|
||||||
{"Single blank #1", fields{username: "test", password: ""}, false},
|
|
||||||
{"Single blank #1", fields{username: "", password: "test"}, false},
|
|
||||||
{"Both populated", fields{username: "test", password: "test"}, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
basicAuth := newHashedBasicAuth(tt.fields.username, tt.fields.password)
|
|
||||||
if got := basicAuth.Enabled(); got != tt.want {
|
|
||||||
t.Errorf("Enabled() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
15
cfg/cfg.go
15
cfg/cfg.go
|
@ -2,9 +2,11 @@ package cfg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/alecthomas/kong"
|
"github.com/alecthomas/kong"
|
||||||
|
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
var cliStruct struct {
|
var cliStruct struct {
|
||||||
|
@ -36,11 +38,22 @@ func Parse() *AppSettings {
|
||||||
Fail2BanSocketPath: cliStruct.F2bSocketPath,
|
Fail2BanSocketPath: cliStruct.F2bSocketPath,
|
||||||
FileCollectorPath: cliStruct.TextFileExporterPath,
|
FileCollectorPath: cliStruct.TextFileExporterPath,
|
||||||
ExitOnSocketConnError: cliStruct.ExitOnSocketError,
|
ExitOnSocketConnError: cliStruct.ExitOnSocketError,
|
||||||
BasicAuthProvider: newHashedBasicAuth(cliStruct.BasicAuthUser, cliStruct.BasicAuthPass),
|
AuthProvider: createAuthProvider(),
|
||||||
}
|
}
|
||||||
return settings
|
return settings
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func createAuthProvider() auth.AuthProvider {
|
||||||
|
username := cliStruct.BasicAuthUser
|
||||||
|
password := cliStruct.BasicAuthPass
|
||||||
|
|
||||||
|
if len(username) == 0 && len(password) == 0 {
|
||||||
|
return auth.NewEmptyAuthProvider()
|
||||||
|
}
|
||||||
|
log.Print("basic auth enabled")
|
||||||
|
return auth.NewBasicAuthProvider(username, password)
|
||||||
|
}
|
||||||
|
|
||||||
func validateFlags(cliCtx *kong.Context) {
|
func validateFlags(cliCtx *kong.Context) {
|
||||||
var flagsValid = true
|
var flagsValid = true
|
||||||
var messages = []string{}
|
var messages = []string{}
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
package cfg
|
package cfg
|
||||||
|
|
||||||
|
import "gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth"
|
||||||
|
|
||||||
type AppSettings struct {
|
type AppSettings struct {
|
||||||
VersionMode bool
|
VersionMode bool
|
||||||
MetricsAddress string
|
MetricsAddress string
|
||||||
Fail2BanSocketPath string
|
Fail2BanSocketPath string
|
||||||
FileCollectorPath string
|
FileCollectorPath string
|
||||||
BasicAuthProvider *hashedBasicAuth
|
AuthProvider auth.AuthProvider
|
||||||
ExitOnSocketConnError bool
|
ExitOnSocketConnError bool
|
||||||
}
|
}
|
||||||
|
|
22
exporter.go
22
exporter.go
|
@ -2,17 +2,18 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
|
||||||
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth"
|
|
||||||
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/cfg"
|
|
||||||
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/collector/f2b"
|
|
||||||
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/collector/textfile"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
|
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/cfg"
|
||||||
|
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/collector/f2b"
|
||||||
|
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/collector/textfile"
|
||||||
|
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/server"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -66,17 +67,14 @@ func main() {
|
||||||
textFileCollector := textfile.NewCollector(appSettings)
|
textFileCollector := textfile.NewCollector(appSettings)
|
||||||
prometheus.MustRegister(textFileCollector)
|
prometheus.MustRegister(textFileCollector)
|
||||||
|
|
||||||
http.HandleFunc("/", auth.BasicAuthMiddleware(rootHtmlHandler, appSettings.BasicAuthProvider))
|
http.HandleFunc("/", server.BasicAuthMiddleware(rootHtmlHandler, appSettings.AuthProvider))
|
||||||
http.HandleFunc(metricsPath, auth.BasicAuthMiddleware(
|
http.HandleFunc(metricsPath, server.BasicAuthMiddleware(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
metricHandler(w, r, textFileCollector)
|
metricHandler(w, r, textFileCollector)
|
||||||
},
|
},
|
||||||
appSettings.BasicAuthProvider,
|
appSettings.AuthProvider,
|
||||||
))
|
))
|
||||||
log.Printf("metrics available at '%s'", metricsPath)
|
log.Printf("metrics available at '%s'", metricsPath)
|
||||||
if appSettings.BasicAuthProvider.Enabled() {
|
|
||||||
log.Printf("basic auth enabled")
|
|
||||||
}
|
|
||||||
|
|
||||||
svrErr := make(chan error)
|
svrErr := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
|
|
17
server/middleware.go
Normal file
17
server/middleware.go
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"gitlab.com/hectorjsmith/fail2ban-prometheus-exporter/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BasicAuthMiddleware(handlerFunc http.HandlerFunc, authProvider auth.AuthProvider) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if authProvider.IsAllowed(r) {
|
||||||
|
handlerFunc.ServeHTTP(w, r)
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
46
server/middleware_test.go
Normal file
46
server/middleware_test.go
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testAuthProvider struct {
|
||||||
|
match bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p testAuthProvider) IsAllowed(request *http.Request) bool {
|
||||||
|
return p.match
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestRequest() *http.Request {
|
||||||
|
return httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func executeBasicAuthMiddlewareTest(t *testing.T, authMatches bool, expectedCode int, expectedCallCount int) {
|
||||||
|
callCount := 0
|
||||||
|
testHandler := func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
callCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := BasicAuthMiddleware(testHandler, testAuthProvider{match: authMatches})
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
request := newTestRequest()
|
||||||
|
handler.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
if recorder.Code != expectedCode {
|
||||||
|
t.Errorf("statusCode = %v, want %v", recorder.Code, expectedCode)
|
||||||
|
}
|
||||||
|
if callCount != expectedCallCount {
|
||||||
|
t.Errorf("callCount = %v, want %v", callCount, expectedCallCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_GIVEN_MatchingBasicAuth_WHEN_MethodCalled_THEN_RequestProcessed(t *testing.T) {
|
||||||
|
executeBasicAuthMiddlewareTest(t, true, http.StatusOK, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_GIVEN_NonMatchingBasicAuth_WHEN_MethodCalled_THEN_RequestRejected(t *testing.T) {
|
||||||
|
executeBasicAuthMiddlewareTest(t, false, http.StatusUnauthorized, 0)
|
||||||
|
}
|
Loading…
Reference in a new issue