diff --git a/cmdutil/service_flags.go b/cmdutil/service_flags.go index a8cdff66f..678a3d8f8 100644 --- a/cmdutil/service_flags.go +++ b/cmdutil/service_flags.go @@ -10,6 +10,7 @@ import ( "net/http" "os" "strings" + "time" "unicode" "github.com/sirupsen/logrus" @@ -147,7 +148,7 @@ func (sf *ServiceFlags) Logger() *logging.Logger { } if discordWebhookURL := discord.GetWebhookURLFromEnv(); discordWebhookURL != "" { - hook := discord.NewHook(sf.Tag, discordWebhookURL) + hook := discord.NewHook(sf.Tag, discordWebhookURL, discord.WithLimit(10*time.Minute)) logging.AddHook(hook) } diff --git a/discord/discord.go b/discord/discord.go deleted file mode 100644 index 6a877db62..000000000 --- a/discord/discord.go +++ /dev/null @@ -1,31 +0,0 @@ -package discord - -import ( - "os" - "time" - - "github.com/kz/discordrus" - "github.com/sirupsen/logrus" -) - -const ( - webhookURLEnvName = "DISCORD_WEBHOOK_URL" -) - -// NewHook creates a new Discord hook. -func NewHook(tag, webHookURL string) logrus.Hook { - return discordrus.NewHook(webHookURL, logrus.ErrorLevel, discordOpts(tag)) -} - -func discordOpts(tag string) *discordrus.Opts { - return &discordrus.Opts{ - Username: tag, - TimestampFormat: time.RFC3339, - TimestampLocale: time.UTC, - } -} - -// GetWebhookURLFromEnv extracts Discord webhook URL from environment variables. -func GetWebhookURLFromEnv() string { - return os.Getenv(webhookURLEnvName) -} diff --git a/discord/hook.go b/discord/hook.go new file mode 100644 index 000000000..b4639ae73 --- /dev/null +++ b/discord/hook.go @@ -0,0 +1,81 @@ +package discord + +import ( + "os" + "time" + + "github.com/kz/discordrus" + "github.com/sirupsen/logrus" +) + +const ( + webhookURLEnvName = "DISCORD_WEBHOOK_URL" +) + +// Hook is a Discord logger hook. +type Hook struct { + logrus.Hook + limit time.Duration + timestamps map[string]time.Time +} + +// Option defines an option for Discord logger hook. +type Option func(*Hook) + +// WithLimit enables logger rate limiter with specified limit. +func WithLimit(limit time.Duration) Option { + return func(h *Hook) { + h.limit = limit + h.timestamps = make(map[string]time.Time) + } +} + +// NewHook returns a new Hook. +func NewHook(tag, webHookURL string, opts ...Option) logrus.Hook { + parent := discordrus.NewHook(webHookURL, logrus.ErrorLevel, discordOpts(tag)) + + hook := &Hook{ + Hook: parent, + } + + for _, opt := range opts { + opt(hook) + } + + return hook +} + +// Fire checks whether rate is fine and fires the underlying hook. +func (h *Hook) Fire(entry *logrus.Entry) error { + if h.shouldFire(entry) { + return h.Hook.Fire(entry) + } + + return nil +} + +func (h *Hook) shouldFire(entry *logrus.Entry) bool { + if h.limit != 0 && h.timestamps != nil { + v, ok := h.timestamps[entry.Message] + if ok && entry.Time.Sub(v) < h.limit { + return false + } + + h.timestamps[entry.Message] = entry.Time + } + + return true +} + +func discordOpts(tag string) *discordrus.Opts { + return &discordrus.Opts{ + Username: tag, + TimestampFormat: time.RFC3339, + TimestampLocale: time.UTC, + } +} + +// GetWebhookURLFromEnv extracts webhook URL from an environment variable. +func GetWebhookURLFromEnv() string { + return os.Getenv(webhookURLEnvName) +} diff --git a/discord/hook_test.go b/discord/hook_test.go new file mode 100644 index 000000000..e0fada3ab --- /dev/null +++ b/discord/hook_test.go @@ -0,0 +1,73 @@ +package discord + +import ( + "testing" + "time" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" +) + +func TestHook_shouldFire(t *testing.T) { + hook := &Hook{ + limit: 1 * time.Millisecond, + timestamps: make(map[string]time.Time), + } + + ts := time.Now() + + tests := []struct { + name string + message string + timestamp time.Time + want bool + }{ + { + name: "Case 1", + message: "Message 1", + timestamp: ts, + want: true, + }, + { + name: "Case 2", + message: "Message 2", + timestamp: ts, + want: true, + }, + { + name: "Case 3", + message: "Message 1", + timestamp: ts, + want: false, + }, + { + name: "Case 4", + message: "Message 1", + timestamp: ts.Add(500 * time.Microsecond), + want: false, + }, + { + name: "Case 5", + message: "Message 1", + timestamp: ts.Add(1500 * time.Microsecond), + want: true, + }, + { + name: "Case 6", + message: "Message 1", + timestamp: ts.Add(2000 * time.Microsecond), + want: false, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + entry := &logrus.Entry{ + Time: tt.timestamp, + Message: tt.message, + } + + assert.Equal(t, tt.want, hook.shouldFire(entry)) + }) + } +}