diff --git a/cache/cache.go b/cache/cache.go index fecc16d..61fc02c 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -2,6 +2,7 @@ package cache import ( + "context" "fmt" "strings" @@ -10,18 +11,18 @@ import ( // Cache is the interface that describes a cache for ntfy-alertmanager. type Cache interface { - Set(fingerprint string, status string) error - Contains(fingerprint string, status string) (bool, error) + Set(ctx context.Context, fingerprint string, status string) error + Contains(ctx context.Context, fingerprint string, status string) (bool, error) Cleanup() } // NewCache reads the cache configuration cfg and creates the cache. -func NewCache(cfg config.CacheConfig) (Cache, error) { +func NewCache(ctx context.Context, cfg config.CacheConfig) (Cache, error) { switch strings.ToLower(cfg.Type) { case "memory": return NewMemoryCache(cfg.Duration), nil case "redis": - return NewRedisCache(cfg.RedisURL, cfg.Duration) + return NewRedisCache(ctx, cfg.RedisURL, cfg.Duration) case "disabled": return NewDisabledCache() default: diff --git a/cache/disabled.go b/cache/disabled.go index 91d9a8f..a667fe9 100644 --- a/cache/disabled.go +++ b/cache/disabled.go @@ -1,5 +1,7 @@ package cache +import "context" + // DisabledCache is the disabled cache. type DisabledCache struct{} @@ -10,12 +12,12 @@ func NewDisabledCache() (Cache, error) { } // Set is an empty function to implement the interface. -func (c *DisabledCache) Set(_ string, _ string) error { +func (c *DisabledCache) Set(_ context.Context, _ string, _ string) error { return nil } // Contains is an empty function to implement the interface. -func (c *DisabledCache) Contains(_ string, _ string) (bool, error) { +func (c *DisabledCache) Contains(_ context.Context, _ string, _ string) (bool, error) { return false, nil } diff --git a/cache/memory.go b/cache/memory.go index 8f11794..bd27ffb 100644 --- a/cache/memory.go +++ b/cache/memory.go @@ -1,6 +1,7 @@ package cache import ( + "context" "sync" "time" ) @@ -27,7 +28,7 @@ func NewMemoryCache(d time.Duration) Cache { } // Set saves an alert in the cache. -func (c *MemoryCache) Set(fingerprint string, status string) error { +func (c *MemoryCache) Set(_ context.Context, fingerprint string, status string) error { c.mu.Lock() defer c.mu.Unlock() alert := new(cachedAlert) @@ -40,7 +41,7 @@ func (c *MemoryCache) Set(fingerprint string, status string) error { // Contains checks if an alert with a given fingerprint is in the cache // and if the status matches. -func (c *MemoryCache) Contains(fingerprint string, status string) (bool, error) { +func (c *MemoryCache) Contains(_ context.Context, fingerprint string, status string) (bool, error) { c.mu.Lock() defer c.mu.Unlock() alert, ok := c.alerts[fingerprint] diff --git a/cache/redis.go b/cache/redis.go index 8911311..67bf2db 100644 --- a/cache/redis.go +++ b/cache/redis.go @@ -16,7 +16,7 @@ type RedisCache struct { } // NewRedisCache creates a new redis cache/client. -func NewRedisCache(redisURL string, d time.Duration) (Cache, error) { +func NewRedisCache(ctx context.Context, redisURL string, d time.Duration) (Cache, error) { c := new(RedisCache) ropts, err := redis.ParseURL(redisURL) if err != nil { @@ -24,7 +24,7 @@ func NewRedisCache(redisURL string, d time.Duration) (Cache, error) { } rdb := redis.NewClient(ropts) - ctx, cancel := context.WithTimeout(context.TODO(), redisTimeout) + ctx, cancel := context.WithTimeout(ctx, redisTimeout) defer cancel() err = rdb.Ping(ctx).Err() @@ -38,8 +38,8 @@ func NewRedisCache(redisURL string, d time.Duration) (Cache, error) { } // Set saves an alert in the cache. -func (c *RedisCache) Set(fingerprint string, status string) error { - ctx, cancel := context.WithTimeout(context.TODO(), redisTimeout) +func (c *RedisCache) Set(ctx context.Context, fingerprint string, status string) error { + ctx, cancel := context.WithTimeout(ctx, redisTimeout) defer cancel() return c.client.SetEx(ctx, fingerprint, status, c.duration).Err() @@ -47,8 +47,8 @@ func (c *RedisCache) Set(fingerprint string, status string) error { // Contains checks if an alert with a given fingerprint is in the cache // and if the status matches. -func (c *RedisCache) Contains(fingerprint string, status string) (bool, error) { - ctx, cancel := context.WithTimeout(context.TODO(), redisTimeout) +func (c *RedisCache) Contains(ctx context.Context, fingerprint string, status string) (bool, error) { + ctx, cancel := context.WithTimeout(ctx, redisTimeout) defer cancel() val, err := c.client.Get(ctx, fingerprint).Result() diff --git a/main.go b/main.go index b1837a1..e6c61ec 100644 --- a/main.go +++ b/main.go @@ -79,10 +79,10 @@ type ntfyError struct { Error string `json:"error"` } -func (br *bridge) singleAlertNotifications(p *payload) []*notification { +func (br *bridge) singleAlertNotifications(ctx context.Context, p *payload) []*notification { var notifications []*notification for _, alert := range p.Alerts { - contains, err := br.cache.Contains(alert.Fingerprint, alert.Status) + contains, err := br.cache.Contains(ctx, alert.Fingerprint, alert.Status) if err != nil { br.logger.Error("Failed to lookup alert in cache", slog.String("fingerprint", alert.Fingerprint), @@ -427,6 +427,7 @@ func (br *bridge) publish(n *notification, topicParam string) error { } func (br *bridge) handleWebhooks(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() logger := br.logger.With(slog.String("handler", "/")) if r.Method != http.MethodPost { @@ -456,14 +457,14 @@ func (br *bridge) handleWebhooks(w http.ResponseWriter, r *http.Request) { slog.Any("payload", event)) if br.cfg.AlertMode == config.Single { - notifications := br.singleAlertNotifications(&event) + notifications := br.singleAlertNotifications(ctx, &event) for _, n := range notifications { err := br.publish(n, topicParam) if err != nil { logger.Error("Failed to publish notification", slog.String("error", err.Error())) } else { - if err := br.cache.Set(n.fingerprint, n.status); err != nil { + if err := br.cache.Set(ctx, n.fingerprint, n.status); err != nil { logger.Error("Failed to cache alert", slog.String("fingerprint", n.fingerprint), slog.String("error", err.Error())) @@ -574,7 +575,7 @@ func main() { client := &httpClient{&http.Client{Timeout: time.Second * 3}} - c, err := cache.NewCache(cfg.Cache) + c, err := cache.NewCache(ctx, cfg.Cache) if err != nil { logger.Error("Failed to create cache", slog.String("error", err.Error()))