Expose context to cache functions

This commit is contained in:
Thorben Günther 2024-11-09 14:13:33 +01:00
parent ea58dec539
commit 182444df3a
No known key found for this signature in database
GPG key ID: 415CD778D8C5AFED
5 changed files with 24 additions and 19 deletions

9
cache/cache.go vendored
View file

@ -2,6 +2,7 @@
package cache package cache
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
@ -10,18 +11,18 @@ import (
// Cache is the interface that describes a cache for ntfy-alertmanager. // Cache is the interface that describes a cache for ntfy-alertmanager.
type Cache interface { type Cache interface {
Set(fingerprint string, status string) error Set(ctx context.Context, fingerprint string, status string) error
Contains(fingerprint string, status string) (bool, error) Contains(ctx context.Context, fingerprint string, status string) (bool, error)
Cleanup() Cleanup()
} }
// NewCache reads the cache configuration cfg and creates the cache. // NewCache reads the cache configuration cfg and creates the cache.
func NewCache(cfg config.CacheConfig) (Cache, error) { func NewCache(ctx context.Context, cfg config.CacheConfig) (Cache, error) {
switch strings.ToLower(cfg.Type) { switch strings.ToLower(cfg.Type) {
case "memory": case "memory":
return NewMemoryCache(cfg.Duration), nil return NewMemoryCache(cfg.Duration), nil
case "redis": case "redis":
return NewRedisCache(cfg.RedisURL, cfg.Duration) return NewRedisCache(ctx, cfg.RedisURL, cfg.Duration)
case "disabled": case "disabled":
return NewDisabledCache() return NewDisabledCache()
default: default:

6
cache/disabled.go vendored
View file

@ -1,5 +1,7 @@
package cache package cache
import "context"
// DisabledCache is the disabled cache. // DisabledCache is the disabled cache.
type DisabledCache struct{} type DisabledCache struct{}
@ -10,12 +12,12 @@ func NewDisabledCache() (Cache, error) {
} }
// Set is an empty function to implement the interface. // Set is an empty function to implement the interface.
func (c *DisabledCache) Set(_ string, _ string) error { func (c *DisabledCache) Set(_ context.Context, _ string, _ string) error {
return nil return nil
} }
// Contains is an empty function to implement the interface. // Contains is an empty function to implement the interface.
func (c *DisabledCache) Contains(_ string, _ string) (bool, error) { func (c *DisabledCache) Contains(_ context.Context, _ string, _ string) (bool, error) {
return false, nil return false, nil
} }

5
cache/memory.go vendored
View file

@ -1,6 +1,7 @@
package cache package cache
import ( import (
"context"
"sync" "sync"
"time" "time"
) )
@ -27,7 +28,7 @@ func NewMemoryCache(d time.Duration) Cache {
} }
// Set saves an alert in the cache. // Set saves an alert in the cache.
func (c *MemoryCache) Set(fingerprint string, status string) error { func (c *MemoryCache) Set(_ context.Context, fingerprint string, status string) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
alert := new(cachedAlert) alert := new(cachedAlert)
@ -40,7 +41,7 @@ func (c *MemoryCache) Set(fingerprint string, status string) error {
// Contains checks if an alert with a given fingerprint is in the cache // Contains checks if an alert with a given fingerprint is in the cache
// and if the status matches. // and if the status matches.
func (c *MemoryCache) Contains(fingerprint string, status string) (bool, error) { func (c *MemoryCache) Contains(_ context.Context, fingerprint string, status string) (bool, error) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
alert, ok := c.alerts[fingerprint] alert, ok := c.alerts[fingerprint]

12
cache/redis.go vendored
View file

@ -16,7 +16,7 @@ type RedisCache struct {
} }
// NewRedisCache creates a new redis cache/client. // NewRedisCache creates a new redis cache/client.
func NewRedisCache(redisURL string, d time.Duration) (Cache, error) { func NewRedisCache(ctx context.Context, redisURL string, d time.Duration) (Cache, error) {
c := new(RedisCache) c := new(RedisCache)
ropts, err := redis.ParseURL(redisURL) ropts, err := redis.ParseURL(redisURL)
if err != nil { if err != nil {
@ -24,7 +24,7 @@ func NewRedisCache(redisURL string, d time.Duration) (Cache, error) {
} }
rdb := redis.NewClient(ropts) rdb := redis.NewClient(ropts)
ctx, cancel := context.WithTimeout(context.TODO(), redisTimeout) ctx, cancel := context.WithTimeout(ctx, redisTimeout)
defer cancel() defer cancel()
err = rdb.Ping(ctx).Err() err = rdb.Ping(ctx).Err()
@ -38,8 +38,8 @@ func NewRedisCache(redisURL string, d time.Duration) (Cache, error) {
} }
// Set saves an alert in the cache. // Set saves an alert in the cache.
func (c *RedisCache) Set(fingerprint string, status string) error { func (c *RedisCache) Set(ctx context.Context, fingerprint string, status string) error {
ctx, cancel := context.WithTimeout(context.TODO(), redisTimeout) ctx, cancel := context.WithTimeout(ctx, redisTimeout)
defer cancel() defer cancel()
return c.client.SetEx(ctx, fingerprint, status, c.duration).Err() return c.client.SetEx(ctx, fingerprint, status, c.duration).Err()
@ -47,8 +47,8 @@ func (c *RedisCache) Set(fingerprint string, status string) error {
// Contains checks if an alert with a given fingerprint is in the cache // Contains checks if an alert with a given fingerprint is in the cache
// and if the status matches. // and if the status matches.
func (c *RedisCache) Contains(fingerprint string, status string) (bool, error) { func (c *RedisCache) Contains(ctx context.Context, fingerprint string, status string) (bool, error) {
ctx, cancel := context.WithTimeout(context.TODO(), redisTimeout) ctx, cancel := context.WithTimeout(ctx, redisTimeout)
defer cancel() defer cancel()
val, err := c.client.Get(ctx, fingerprint).Result() val, err := c.client.Get(ctx, fingerprint).Result()

11
main.go
View file

@ -79,10 +79,10 @@ type ntfyError struct {
Error string `json:"error"` Error string `json:"error"`
} }
func (br *bridge) singleAlertNotifications(p *payload) []*notification { func (br *bridge) singleAlertNotifications(ctx context.Context, p *payload) []*notification {
var notifications []*notification var notifications []*notification
for _, alert := range p.Alerts { for _, alert := range p.Alerts {
contains, err := br.cache.Contains(alert.Fingerprint, alert.Status) contains, err := br.cache.Contains(ctx, alert.Fingerprint, alert.Status)
if err != nil { if err != nil {
br.logger.Error("Failed to lookup alert in cache", br.logger.Error("Failed to lookup alert in cache",
slog.String("fingerprint", alert.Fingerprint), slog.String("fingerprint", alert.Fingerprint),
@ -427,6 +427,7 @@ func (br *bridge) publish(n *notification, topicParam string) error {
} }
func (br *bridge) handleWebhooks(w http.ResponseWriter, r *http.Request) { func (br *bridge) handleWebhooks(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
logger := br.logger.With(slog.String("handler", "/")) logger := br.logger.With(slog.String("handler", "/"))
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
@ -456,14 +457,14 @@ func (br *bridge) handleWebhooks(w http.ResponseWriter, r *http.Request) {
slog.Any("payload", event)) slog.Any("payload", event))
if br.cfg.AlertMode == config.Single { if br.cfg.AlertMode == config.Single {
notifications := br.singleAlertNotifications(&event) notifications := br.singleAlertNotifications(ctx, &event)
for _, n := range notifications { for _, n := range notifications {
err := br.publish(n, topicParam) err := br.publish(n, topicParam)
if err != nil { if err != nil {
logger.Error("Failed to publish notification", logger.Error("Failed to publish notification",
slog.String("error", err.Error())) slog.String("error", err.Error()))
} else { } else {
if err := br.cache.Set(n.fingerprint, n.status); err != nil { if err := br.cache.Set(ctx, n.fingerprint, n.status); err != nil {
logger.Error("Failed to cache alert", logger.Error("Failed to cache alert",
slog.String("fingerprint", n.fingerprint), slog.String("fingerprint", n.fingerprint),
slog.String("error", err.Error())) slog.String("error", err.Error()))
@ -574,7 +575,7 @@ func main() {
client := &httpClient{&http.Client{Timeout: time.Second * 3}} client := &httpClient{&http.Client{Timeout: time.Second * 3}}
c, err := cache.NewCache(cfg.Cache) c, err := cache.NewCache(ctx, cfg.Cache)
if err != nil { if err != nil {
logger.Error("Failed to create cache", logger.Error("Failed to create cache",
slog.String("error", err.Error())) slog.String("error", err.Error()))