github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/pkg/plugin/oauth2/loader.go (about) 1 package oauth2 2 3 import ( 4 "github.com/rs/cors" 5 log "github.com/sirupsen/logrus" 6 "github.com/ulule/limiter/v3" 7 "github.com/ulule/limiter/v3/drivers/middleware/stdlib" 8 storeMemory "github.com/ulule/limiter/v3/drivers/store/memory" 9 10 "github.com/hellofresh/janus/pkg/proxy" 11 "github.com/hellofresh/janus/pkg/router" 12 ) 13 14 // OAuthLoader handles the loading of the api specs 15 type OAuthLoader struct { 16 register *proxy.Register 17 } 18 19 // NewOAuthLoader creates a new instance of the Loader 20 func NewOAuthLoader(register *proxy.Register) *OAuthLoader { 21 return &OAuthLoader{register} 22 } 23 24 // LoadDefinitions loads all oauth servers from a data source 25 func (m *OAuthLoader) LoadDefinitions(repo Repository) { 26 oAuthServers := m.getOAuthServers(repo) 27 m.RegisterOAuthServers(oAuthServers, repo) 28 } 29 30 // RegisterOAuthServers register many oauth servers 31 func (m *OAuthLoader) RegisterOAuthServers(oauthServers []*Spec, repo Repository) { 32 log.Debug("Loading OAuth servers configurations") 33 34 for _, oauthServer := range oauthServers { 35 var mw []router.Constructor 36 37 logger := log.WithField("name", oauthServer.Name) 38 logger.Debug("Registering OAuth server") 39 40 corsHandler := cors.New(cors.Options{ 41 AllowedOrigins: oauthServer.CorsMeta.Domains, 42 AllowedMethods: oauthServer.CorsMeta.Methods, 43 AllowedHeaders: oauthServer.CorsMeta.RequestHeaders, 44 ExposedHeaders: oauthServer.CorsMeta.ExposedHeaders, 45 OptionsPassthrough: oauthServer.CorsMeta.OptionsPassthrough, 46 AllowCredentials: true, 47 }).Handler 48 49 mw = append(mw, corsHandler) 50 51 if oauthServer.RateLimit.Enabled { 52 rate, err := limiter.NewRateFromFormatted(oauthServer.RateLimit.Limit) 53 if err != nil { 54 logger.WithError(err).Error("Not able to create rate limit") 55 } 56 57 limiterStore := storeMemory.NewStore() 58 limiterInstance := limiter.New(limiterStore, rate) 59 rateLimitHandler := stdlib.NewMiddleware(limiterInstance).Handler 60 61 mw = append(mw, rateLimitHandler) 62 } 63 64 endpoints := map[*proxy.RouterDefinition][]router.Constructor{ 65 proxy.NewRouterDefinition(oauthServer.Endpoints.Authorize): mw, 66 proxy.NewRouterDefinition(oauthServer.Endpoints.Token): append(mw, NewSecretMiddleware(oauthServer).Handler), 67 proxy.NewRouterDefinition(oauthServer.Endpoints.Introspect): mw, 68 proxy.NewRouterDefinition(oauthServer.Endpoints.Revoke): mw, 69 proxy.NewRouterDefinition(oauthServer.ClientEndpoints.Create): mw, 70 proxy.NewRouterDefinition(oauthServer.ClientEndpoints.Remove): mw, 71 } 72 73 m.registerRoutes(endpoints) 74 logger.Debug("OAuth server registered") 75 } 76 77 log.Debug("Done loading OAuth servers configurations") 78 } 79 80 func (m *OAuthLoader) getOAuthServers(repo Repository) []*Spec { 81 oauthServers, err := repo.FindAll() 82 if err != nil { 83 log.Panic(err) 84 } 85 86 var specs []*Spec 87 for _, oauthServer := range oauthServers { 88 spec := new(Spec) 89 spec.OAuth = oauthServer 90 manager, err := m.getManager(oauthServer) 91 if nil != err { 92 log.WithError(err).Error("Oauth definition is not well configured, skipping...") 93 continue 94 } 95 spec.Manager = manager 96 specs = append(specs, spec) 97 } 98 99 return specs 100 } 101 102 func (m *OAuthLoader) getManager(oauthServer *OAuth) (Manager, error) { 103 managerType, err := ParseType(oauthServer.TokenStrategy.Name) 104 if nil != err { 105 return nil, err 106 } 107 108 return NewManagerFactory(oauthServer).Build(managerType) 109 } 110 111 func (m *OAuthLoader) registerRoutes(endpoints map[*proxy.RouterDefinition][]router.Constructor) { 112 for endpoint, middleware := range endpoints { 113 if endpoint.Definition == nil || endpoint.Definition.ListenPath == "" { 114 log.Debug("Endpoint not registered") 115 continue 116 } 117 118 for _, mw := range middleware { 119 endpoint.AddMiddleware(mw) 120 } 121 122 l := log.WithField("listen_path", endpoint.ListenPath) 123 l.Debug("Registering OAuth endpoint") 124 if isValid, err := endpoint.Validate(); isValid && err == nil { 125 m.register.Add(endpoint) 126 l.Debug("Endpoint registered") 127 } else { 128 l.WithError(err).Error("Error when registering endpoint") 129 } 130 } 131 }