github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/pkg/plugin/oauth2/setup.go (about)

     1  package oauth2
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net/url"
     8  	"time"
     9  
    10  	"github.com/asaskevich/govalidator"
    11  	log "github.com/sirupsen/logrus"
    12  	"go.mongodb.org/mongo-driver/mongo"
    13  	"go.mongodb.org/mongo-driver/mongo/options"
    14  	"go.mongodb.org/mongo-driver/x/bsonx"
    15  
    16  	"github.com/hellofresh/janus/pkg/config"
    17  	"github.com/hellofresh/janus/pkg/jwt"
    18  	"github.com/hellofresh/janus/pkg/plugin"
    19  	"github.com/hellofresh/janus/pkg/proxy"
    20  	"github.com/hellofresh/janus/pkg/router"
    21  )
    22  
    23  const (
    24  	mongodb = "mongodb"
    25  	file    = "file"
    26  	cassandra = "cassandra"
    27  
    28  	mongoIdxTimeout = 10 * time.Second
    29  )
    30  
    31  var (
    32  	repo        Repository
    33  	loader      *OAuthLoader
    34  	adminRouter router.Router
    35  )
    36  
    37  func init() {
    38  	plugin.RegisterEventHook(plugin.StartupEvent, onStartup)
    39  	plugin.RegisterEventHook(plugin.ReloadEvent, onReload)
    40  	plugin.RegisterEventHook(plugin.AdminAPIStartupEvent, onAdminAPIStartup)
    41  	plugin.RegisterPlugin("oauth2", plugin.Plugin{
    42  		Action:   setupOAuth2,
    43  		Validate: validateConfig,
    44  	})
    45  }
    46  
    47  // Config represents the oauth configuration
    48  type Config struct {
    49  	ServerName string `json:"server_name"`
    50  }
    51  
    52  func onAdminAPIStartup(event interface{}) error {
    53  	e, ok := event.(plugin.OnAdminAPIStartup)
    54  	if !ok {
    55  		return errors.New("could not convert event to admin startup type")
    56  	}
    57  
    58  	adminRouter = e.Router
    59  	return nil
    60  }
    61  
    62  func onReload(event interface{}) error {
    63  	_, ok := event.(plugin.OnReload)
    64  	if !ok {
    65  		return errors.New("could not convert event to reload type")
    66  	}
    67  
    68  	loader.LoadDefinitions(repo)
    69  
    70  	return nil
    71  }
    72  
    73  func onStartup(event interface{}) error {
    74  	e, ok := event.(plugin.OnStartup)
    75  	if !ok {
    76  		return errors.New("could not convert event to startup type")
    77  	}
    78  
    79  	cfg := e.Config.Database
    80  	dsnURL, err := url.Parse(cfg.DSN)
    81  	if err != nil {
    82  		return err
    83  	}
    84  
    85  	switch dsnURL.Scheme {
    86  	case mongodb:
    87  		repo, err = NewMongoRepository(e.MongoDB)
    88  		if err != nil {
    89  			return fmt.Errorf("could not create a mongodb repository for oauth servers: %w", err)
    90  		}
    91  
    92  		ctx, cancel := context.WithTimeout(context.Background(), mongoIdxTimeout)
    93  		defer cancel()
    94  
    95  		if _, err := e.MongoDB.Collection(collectionName).Indexes().CreateOne(
    96  			ctx,
    97  			mongo.IndexModel{
    98  				Keys: bsonx.Doc{
    99  					{Key: "name", Value: bsonx.Int32(1)},
   100  				},
   101  				Options: options.Index().SetUnique(true).SetBackground(true).SetSparse(true),
   102  			},
   103  		); err != nil {
   104  			return fmt.Errorf("failed to create indexes for oauth servers repository: %w", err)
   105  		}
   106  	case cassandra:
   107  		repo, err = NewCassandraRepository(e.Cassandra)
   108  		if err != nil {
   109  			log.Errorf("error creating new cassandra repo")
   110  			return err
   111  		}
   112  	case file:
   113  		authPath := fmt.Sprintf("%s/auth", dsnURL.Path)
   114  		log.WithField("path", authPath).Debug("Trying to load Auth configuration files")
   115  
   116  		repo, err = NewFileSystemRepository(authPath)
   117  		if err != nil {
   118  			return fmt.Errorf("could not create a file based repository for the oauth servers: %w", err)
   119  		}
   120  
   121  	default:
   122  		return errors.New("the selected scheme is not supported to load OAuth servers")
   123  	}
   124  
   125  	loadOAuthEndpoints(adminRouter, repo, e.Config.Web.Credentials)
   126  	loader = NewOAuthLoader(e.Register)
   127  	loader.LoadDefinitions(repo)
   128  
   129  	return nil
   130  }
   131  
   132  func setupOAuth2(def *proxy.RouterDefinition, rawConfig plugin.Config) error {
   133  	var cfg Config
   134  	err := plugin.Decode(rawConfig, &cfg)
   135  	if err != nil {
   136  		return err
   137  	}
   138  
   139  	oauthServer, err := repo.FindByName(cfg.ServerName)
   140  	if nil != err {
   141  		return err
   142  	}
   143  
   144  	manager, err := getManager(oauthServer, cfg.ServerName)
   145  	if nil != err {
   146  		log.WithError(err).Error("OAuth Configuration for this API is incorrect, skipping...")
   147  		return err
   148  	}
   149  
   150  	signingMethods, err := oauthServer.TokenStrategy.GetJWTSigningMethods()
   151  	if err != nil {
   152  		return err
   153  	}
   154  
   155  	def.AddMiddleware(NewKeyExistsMiddleware(manager))
   156  	def.AddMiddleware(NewRevokeRulesMiddleware(jwt.NewParser(jwt.NewParserConfig(oauthServer.TokenStrategy.Leeway, signingMethods...)), oauthServer.AccessRules))
   157  
   158  	return nil
   159  }
   160  
   161  func validateConfig(rawConfig plugin.Config) (bool, error) {
   162  	var cfg Config
   163  	err := plugin.Decode(rawConfig, &cfg)
   164  	if err != nil {
   165  		return false, err
   166  	}
   167  
   168  	return govalidator.ValidateStruct(cfg)
   169  }
   170  
   171  func getManager(oauthServer *OAuth, oAuthServerName string) (Manager, error) {
   172  	managerType, err := ParseType(oauthServer.TokenStrategy.Name)
   173  	if nil != err {
   174  		return nil, err
   175  	}
   176  
   177  	return NewManagerFactory(oauthServer).Build(managerType)
   178  }
   179  
   180  // loadOAuthEndpoints register api endpoints
   181  func loadOAuthEndpoints(router router.Router, repo Repository, cred config.Credentials) {
   182  	log.Debug("Loading OAuth Endpoints")
   183  
   184  	guard := jwt.NewGuard(cred)
   185  	oAuthHandler := NewController(repo)
   186  	oauthGroup := router.Group("/oauth/servers")
   187  	oauthGroup.Use(jwt.NewMiddleware(guard).Handler)
   188  	{
   189  		oauthGroup.GET("/", oAuthHandler.Get())
   190  		oauthGroup.GET("/{name}", oAuthHandler.GetBy())
   191  		oauthGroup.POST("/", oAuthHandler.Post())
   192  		oauthGroup.PUT("/{name}", oAuthHandler.PutBy())
   193  		oauthGroup.DELETE("/{name}", oAuthHandler.DeleteBy())
   194  	}
   195  }