github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/pkg/plugin/oauth2/loader.go (about)

     1  package oauth2
     2  
     3  import (
     4  	"github.com/rs/cors"
     5  	log "github.com/sirupsen/logrus"
     6  	"github.com/ulule/limiter/v3"
     7  	"github.com/ulule/limiter/v3/drivers/middleware/stdlib"
     8  	storeMemory "github.com/ulule/limiter/v3/drivers/store/memory"
     9  
    10  	"github.com/hellofresh/janus/pkg/proxy"
    11  	"github.com/hellofresh/janus/pkg/router"
    12  )
    13  
    14  // OAuthLoader handles the loading of the api specs
    15  type OAuthLoader struct {
    16  	register *proxy.Register
    17  }
    18  
    19  // NewOAuthLoader creates a new instance of the Loader
    20  func NewOAuthLoader(register *proxy.Register) *OAuthLoader {
    21  	return &OAuthLoader{register}
    22  }
    23  
    24  // LoadDefinitions loads all oauth servers from a data source
    25  func (m *OAuthLoader) LoadDefinitions(repo Repository) {
    26  	oAuthServers := m.getOAuthServers(repo)
    27  	m.RegisterOAuthServers(oAuthServers, repo)
    28  }
    29  
    30  // RegisterOAuthServers register many oauth servers
    31  func (m *OAuthLoader) RegisterOAuthServers(oauthServers []*Spec, repo Repository) {
    32  	log.Debug("Loading OAuth servers configurations")
    33  
    34  	for _, oauthServer := range oauthServers {
    35  		var mw []router.Constructor
    36  
    37  		logger := log.WithField("name", oauthServer.Name)
    38  		logger.Debug("Registering OAuth server")
    39  
    40  		corsHandler := cors.New(cors.Options{
    41  			AllowedOrigins:     oauthServer.CorsMeta.Domains,
    42  			AllowedMethods:     oauthServer.CorsMeta.Methods,
    43  			AllowedHeaders:     oauthServer.CorsMeta.RequestHeaders,
    44  			ExposedHeaders:     oauthServer.CorsMeta.ExposedHeaders,
    45  			OptionsPassthrough: oauthServer.CorsMeta.OptionsPassthrough,
    46  			AllowCredentials:   true,
    47  		}).Handler
    48  
    49  		mw = append(mw, corsHandler)
    50  
    51  		if oauthServer.RateLimit.Enabled {
    52  			rate, err := limiter.NewRateFromFormatted(oauthServer.RateLimit.Limit)
    53  			if err != nil {
    54  				logger.WithError(err).Error("Not able to create rate limit")
    55  			}
    56  
    57  			limiterStore := storeMemory.NewStore()
    58  			limiterInstance := limiter.New(limiterStore, rate)
    59  			rateLimitHandler := stdlib.NewMiddleware(limiterInstance).Handler
    60  
    61  			mw = append(mw, rateLimitHandler)
    62  		}
    63  
    64  		endpoints := map[*proxy.RouterDefinition][]router.Constructor{
    65  			proxy.NewRouterDefinition(oauthServer.Endpoints.Authorize):    mw,
    66  			proxy.NewRouterDefinition(oauthServer.Endpoints.Token):        append(mw, NewSecretMiddleware(oauthServer).Handler),
    67  			proxy.NewRouterDefinition(oauthServer.Endpoints.Introspect):   mw,
    68  			proxy.NewRouterDefinition(oauthServer.Endpoints.Revoke):       mw,
    69  			proxy.NewRouterDefinition(oauthServer.ClientEndpoints.Create): mw,
    70  			proxy.NewRouterDefinition(oauthServer.ClientEndpoints.Remove): mw,
    71  		}
    72  
    73  		m.registerRoutes(endpoints)
    74  		logger.Debug("OAuth server registered")
    75  	}
    76  
    77  	log.Debug("Done loading OAuth servers configurations")
    78  }
    79  
    80  func (m *OAuthLoader) getOAuthServers(repo Repository) []*Spec {
    81  	oauthServers, err := repo.FindAll()
    82  	if err != nil {
    83  		log.Panic(err)
    84  	}
    85  
    86  	var specs []*Spec
    87  	for _, oauthServer := range oauthServers {
    88  		spec := new(Spec)
    89  		spec.OAuth = oauthServer
    90  		manager, err := m.getManager(oauthServer)
    91  		if nil != err {
    92  			log.WithError(err).Error("Oauth definition is not well configured, skipping...")
    93  			continue
    94  		}
    95  		spec.Manager = manager
    96  		specs = append(specs, spec)
    97  	}
    98  
    99  	return specs
   100  }
   101  
   102  func (m *OAuthLoader) getManager(oauthServer *OAuth) (Manager, error) {
   103  	managerType, err := ParseType(oauthServer.TokenStrategy.Name)
   104  	if nil != err {
   105  		return nil, err
   106  	}
   107  
   108  	return NewManagerFactory(oauthServer).Build(managerType)
   109  }
   110  
   111  func (m *OAuthLoader) registerRoutes(endpoints map[*proxy.RouterDefinition][]router.Constructor) {
   112  	for endpoint, middleware := range endpoints {
   113  		if endpoint.Definition == nil || endpoint.Definition.ListenPath == "" {
   114  			log.Debug("Endpoint not registered")
   115  			continue
   116  		}
   117  
   118  		for _, mw := range middleware {
   119  			endpoint.AddMiddleware(mw)
   120  		}
   121  
   122  		l := log.WithField("listen_path", endpoint.ListenPath)
   123  		l.Debug("Registering OAuth endpoint")
   124  		if isValid, err := endpoint.Validate(); isValid && err == nil {
   125  			m.register.Add(endpoint)
   126  			l.Debug("Endpoint registered")
   127  		} else {
   128  			l.WithError(err).Error("Error when registering endpoint")
   129  		}
   130  	}
   131  }