github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/pkg/plugin/rate/setup.go (about) 1 package rate 2 3 import ( 4 "net/http" 5 "time" 6 7 "github.com/asaskevich/govalidator" 8 "github.com/go-redis/redis/v7" 9 "github.com/hellofresh/stats-go/client" 10 "github.com/ulule/limiter/v3" 11 "github.com/ulule/limiter/v3/drivers/middleware/stdlib" 12 storeMemory "github.com/ulule/limiter/v3/drivers/store/memory" 13 storeRedis "github.com/ulule/limiter/v3/drivers/store/redis" 14 15 "github.com/hellofresh/janus/pkg/errors" 16 "github.com/hellofresh/janus/pkg/plugin" 17 "github.com/hellofresh/janus/pkg/proxy" 18 ) 19 20 var ( 21 statsClient client.Client 22 // ErrInvalidPolicy is used when an invalid policy was provided 23 ErrInvalidPolicy = errors.New(http.StatusBadRequest, "policy is not supported") 24 ) 25 26 const ( 27 // DefaultPrefix is the default prefix to use for the key in the store. 28 DefaultPrefix = "limiter" 29 ) 30 31 // Config represents a rate limit config 32 type Config struct { 33 Limit string `json:"limit"` 34 Policy string `json:"policy"` 35 RedisConfig redisConfig `json:"redis"` 36 TrustForwardHeaders bool `json:"trust_forward_headers"` 37 } 38 39 type redisConfig struct { 40 DSN string `json:"dsn"` 41 Prefix string `json:"prefix"` 42 } 43 44 func init() { 45 plugin.RegisterEventHook(plugin.StartupEvent, onStartup) 46 plugin.RegisterPlugin("rate_limit", plugin.Plugin{ 47 Action: setupRateLimit, 48 Validate: validateConfig, 49 }) 50 } 51 52 func onStartup(event interface{}) error { 53 e, ok := event.(plugin.OnStartup) 54 if !ok { 55 return errors.New(http.StatusInternalServerError, "Could not convert event to startup type") 56 } 57 58 statsClient = e.StatsClient 59 return nil 60 } 61 62 func validateConfig(rawConfig plugin.Config) (bool, error) { 63 var config Config 64 err := plugin.Decode(rawConfig, &config) 65 if err != nil { 66 return false, err 67 } 68 69 return govalidator.ValidateStruct(config) 70 } 71 72 func setupRateLimit(def *proxy.RouterDefinition, rawConfig plugin.Config) error { 73 var config Config 74 err := plugin.Decode(rawConfig, &config) 75 if err != nil { 76 return err 77 } 78 79 rate, err := limiter.NewRateFromFormatted(config.Limit) 80 if err != nil { 81 return err 82 } 83 84 limiterStore, err := getLimiterStore(config.Policy, config.RedisConfig) 85 if err != nil { 86 return err 87 } 88 89 limiterInstance := limiter.New(limiterStore, rate, limiter.WithTrustForwardHeader(config.TrustForwardHeaders)) 90 def.AddMiddleware(NewRateLimitLogger(limiterInstance, statsClient, config.TrustForwardHeaders)) 91 def.AddMiddleware(stdlib.NewMiddleware(limiterInstance).Handler) 92 93 return nil 94 } 95 96 func getLimiterStore(policy string, config redisConfig) (limiter.Store, error) { 97 switch policy { 98 case "redis": 99 option, err := redis.ParseURL(config.DSN) 100 if err != nil { 101 return nil, err 102 } 103 option.PoolSize = 3 104 option.IdleTimeout = 240 * time.Second 105 redisClient := redis.NewClient(option) 106 107 if config.Prefix == "" { 108 config.Prefix = DefaultPrefix 109 } 110 111 return storeRedis.NewStoreWithOptions(redisClient, limiter.StoreOptions{ 112 Prefix: config.Prefix, 113 MaxRetry: limiter.DefaultMaxRetry, 114 }) 115 116 case "local": 117 return storeMemory.NewStore(), nil 118 119 default: 120 return nil, ErrInvalidPolicy 121 } 122 }